sweeps.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. """W&B Public API for Sweeps.
  2. This module provides classes for interacting with W&B hyperparameter
  3. optimization sweeps.
  4. Example:
  5. ```python
  6. from wandb.apis.public import Api
  7. # Get a specific sweep
  8. sweep = Api().sweep("entity/project/sweep_id")
  9. # Access sweep properties
  10. print(f"Sweep: {sweep.name}")
  11. print(f"State: {sweep.state}")
  12. print(f"Best Loss: {sweep.best_loss}")
  13. # Get best performing run
  14. best_run = sweep.best_run()
  15. print(f"Best Run: {best_run.name}")
  16. print(f"Metrics: {best_run.summary}")
  17. ```
  18. Note:
  19. This module is part of the W&B Public API and provides read-only access
  20. to sweep data. For creating and controlling sweeps, use the wandb.sweep()
  21. and wandb.agent() functions from the main wandb package.
  22. """
  23. from __future__ import annotations
  24. import urllib
  25. from collections.abc import Mapping
  26. from typing import TYPE_CHECKING, Any, ClassVar
  27. from typing_extensions import override
  28. from wandb_gql import gql
  29. from wandb_graphql.language.ast import Document
  30. import wandb
  31. from wandb import util
  32. from wandb.apis import public
  33. from wandb.apis.attrs import Attrs
  34. from wandb.apis.paginator import SizedPaginator
  35. from wandb.errors import Error
  36. from wandb.sdk.lib import ipython
  37. if TYPE_CHECKING:
  38. from wandb.apis._generated import GetSweeps
  39. from wandb.apis.public.api import RetryingClient
  40. from wandb.apis.public.runs import AgentRuns
  class Sweeps(SizedPaginator["Sweep"]):
      """A lazy iterator over a collection of `Sweep` objects.

      Examples:
      ```python
      from wandb.apis.public import Api

      sweeps = Api().project(name="project_name", entity="entity").sweeps()

      # Iterate over sweeps and print details
      for sweep in sweeps:
          print(f"Sweep name: {sweep.name}")
          print(f"Sweep ID: {sweep.id}")
          print(f"Sweep URL: {sweep.url}")
          print("----------")
      ```
      """

      # Parsed GraphQL document; built on first instantiation and cached on the class.
      QUERY: ClassVar[Document | None] = None
      # Validated payload of the most recently fetched page, if any.
      last_response: GetSweeps | None

      def __init__(
          self,
          client: RetryingClient,
          entity: str,
          project: str,
          per_page: int = 50,
      ) -> None:
          """An iterable collection of `Sweep` objects.

          Args:
              client: The API client used to query W&B.
              entity: The entity which owns the sweeps.
              project: The project which contains the sweeps.
              per_page: The number of sweeps to fetch per request to the API.
          """
          if self.QUERY is None:
              # Imported lazily and parsed only once; later instances reuse the
              # document stored on the class.
              from wandb.apis._generated import GET_SWEEPS_GQL

              type(self).QUERY = gql(GET_SWEEPS_GQL)

          self.entity = entity
          self.project = project
          variables = {"project": self.project, "entity": self.entity}
          super().__init__(client, variables, per_page)

      @override
      def _update_response(self) -> None:
          """Fetch and validate the response data for the current page."""
          from wandb.apis._generated import GetSweeps

          data = self.client.execute(self.QUERY, variable_values=self.variables)
          self.last_response = GetSweeps.model_validate(data)

      @property
      @override
      def _length(self) -> int:
          """The total number of sweeps in the project.

          Loads the first page if nothing has been fetched yet. A missing
          (`None`) total is reported as 0.

          <!-- lazydoc-ignore: internal -->
          """
          if self.last_response is None:
              self._load_page()
          total = self.last_response.project.total_sweeps
          return 0 if total is None else total

      @property
      @override
      def more(self) -> bool:
          """Returns whether there are more sweeps to fetch.

          Before any page has been fetched this is always `True`.

          <!-- lazydoc-ignore: internal -->
          """
          if self.last_response:
              return self.last_response.project.sweeps.page_info.has_next_page
          return True

      @property
      @override
      def cursor(self) -> str | None:
          """Returns the cursor for the next page of sweeps.

          `None` until a page has been fetched.

          <!-- lazydoc-ignore: internal -->
          """
          if self.last_response:
              return self.last_response.project.sweeps.page_info.end_cursor
          return None

      @override
      def convert_objects(self) -> list[Sweep]:
          """Converts the last GraphQL response into a list of `Sweep` objects.

          Raises:
              ValueError: If there is no response yet or the response contains
                  no project.

          <!-- lazydoc-ignore: internal -->
          """
          from wandb._pydantic import Connection
          from wandb.apis._generated import SweepFragment

          if (rsp := self.last_response) is None or (project := rsp.project) is None:
              msg = f"Could not find project {self.project!r}"
              raise ValueError(msg)

          # `total_sweeps` may be None (see `_length`); treat that like zero
          # instead of failing on a `None < 1` comparison.
          if (project.total_sweeps or 0) < 1:
              return []

          # Only the node name is forwarded; each `Sweep` is constructed
          # without attrs.
          return [
              Sweep(
                  self.client,
                  self.entity,
                  self.project,
                  node.name,
              )
              for node in Connection[SweepFragment].model_validate(project.sweeps).nodes()
          ]

      def __repr__(self) -> str:
          return f"<Sweeps {self.entity}/{self.project}>"
  class Sweep(Attrs):
      """The set of runs associated with the sweep.

      Attributes:
          runs (Runs): List of runs
          id (str): Sweep ID
          project (str): The name of the project the sweep belongs to
          config (dict): Dictionary containing the sweep configuration
          state (str): The state of the sweep. Can be "Finished", "Failed",
              "Crashed", or "Running".
          expected_run_count (int): The number of expected runs for the sweep
      """

      def __init__(
          self,
          client: RetryingClient,
          entity: str,
          project: str,
          sweep_id: str,
          attrs: Mapping[str, Any] | None = None,
      ) -> None:
          """Create a sweep handle, fetching its data when no attrs are given.

          Args:
              client: The API client used to query W&B.
              entity: The entity which owns the sweep.
              project: The project which contains the sweep.
              sweep_id: The ID of the sweep.
              attrs: Already-fetched sweep attributes. When empty or omitted,
                  the sweep is loaded via `load(force=True)`.
          """
          # TODO: Add agents / flesh this out.
          super().__init__(dict(attrs or {}))
          self.client = client
          self._entity = entity
          self.project = project
          self.id = sweep_id
          self.runs = []

          self.load(force=not attrs)

      @property
      def entity(self) -> str:
          """The entity associated with the sweep."""
          return self._entity

      @property
      def username(self) -> str:
          """Deprecated. Use `Sweep.entity` instead."""
          wandb.termwarn("Sweep.username is deprecated. please use Sweep.entity instead.")
          return self._entity

      @property
      def config(self):
          """The sweep configuration used for the sweep.

          The stored YAML is parsed on every access.
          """
          return util.load_yaml(self._attrs["config"])

      def load(self, force: bool = False):
          """Fetch and update sweep data logged to the run from GraphQL database.

          Args:
              force: Fetch even when attributes are already present.

          Returns:
              The sweep's attribute dictionary.

          Raises:
              ValueError: If the sweep cannot be found.

          <!-- lazydoc-ignore: internal -->
          """
          if force or not self._attrs:
              if not (sweep := self.get(self.client, self.entity, self.project, self.id)):
                  raise ValueError(f"Could not find sweep {self!r}")
              self._attrs = sweep._attrs
              self.runs = sweep.runs
          return self._attrs

      @property
      def order(self) -> str | None:
          """Return the order key for the sweep.

          Returns `None` when the sweep has no config or the config defines no
          metric. A "minimize" goal (the default) sorts ascending ("+"); any
          other goal sorts descending ("-").
          """
          if not self._attrs.get("config"):
              return None

          # Parse the YAML config once instead of on every lookup below.
          metric = self.config.get("metric")
          if not metric:
              return None

          goal = metric.get("goal", "minimize")
          prefix = "+" if goal == "minimize" else "-"
          return public.QueryGenerator.format_order_key(prefix + metric["name"])

      def best_run(self, order: str | None = None):
          """Return the best run sorted by the metric defined in config or the order passed in.

          Args:
              order: Explicit order key. When omitted, the key derived from the
                  sweep config (`Sweep.order`) is used.

          Returns:
              The first run under the chosen ordering, or `None` if the sweep
              has no runs.
          """
          if order is None:
              order = self.order
          else:
              order = public.QueryGenerator.format_order_key(order)

          if order is None:
              wandb.termwarn(
                  "No order specified and couldn't find metric in sweep config, returning most recent run"
              )
          else:
              wandb.termlog(f"Sorting runs by {order}")

          filters = {"$and": [{"sweep": self.id}]}
          try:
              # Only the first run is needed, so request a single-item page.
              return public.Runs(
                  self.client,
                  self.entity,
                  self.project,
                  order=order,
                  filters=filters,
                  per_page=1,
              )[0]
          except IndexError:
              return None

      @property
      def expected_run_count(self) -> int | None:
          """Return the number of expected runs in the sweep or None for infinite runs."""
          return self._attrs.get("runCountExpected")

      @property
      def path(self) -> list[str]:
          """Returns the path of the sweep.

          The path is a new list containing the URL-quoted entity, project
          name, and sweep ID.
          """
          # A bare `import urllib` does not guarantee that the `parse` submodule
          # is loaded, so import the function explicitly.
          from urllib.parse import quote_plus

          return [
              quote_plus(self.entity),
              quote_plus(self.project),
              quote_plus(self.id),
          ]

      @property
      def url(self) -> str:
          """The URL of the sweep.

          The sweep URL is generated from the entity, project, the term
          "sweeps", and the sweep ID. For SaaS users, it takes the form
          of `https://wandb.ai/entity/project/sweeps/sweep_ID`.
          """
          entity, project, sweep_id = self.path
          return self.client.app_url + "/".join((entity, project, "sweeps", sweep_id))

      @property
      def name(self) -> str:
          """The name of the sweep.

          Returns the first name that exists in the following priority order:

          1. User-edited display name
          2. Name configured at creation time
          3. Sweep ID
          """
          display_name = self._attrs.get("displayName")
          if display_name:
              return display_name

          # Only consult the config when one is stored, so a sweep without a
          # config still falls back to its ID instead of raising.
          if self._attrs.get("config"):
              configured_name = (self.config or {}).get("name")
              if configured_name:
                  return configured_name

          return self.id

      @classmethod
      def get(
          cls,
          client: RetryingClient,
          entity: str | None = None,
          project: str | None = None,
          sid: str | None = None,
          order: str | None = None,
          query: Document | None = None,
          **kwargs,
      ):
          """Execute a query against the cloud backend.

          Args:
              client: The client to use to execute the query.
              entity: The entity (username or team) that owns the project.
              project: The name of the project to fetch sweep from.
              sid: The sweep ID to query.
              order: The order in which the sweep's runs are returned.
                  Defaults to "+created_at".
              query: The query to use to execute the query.
              **kwargs: Additional keyword arguments to pass to the query.

          Returns:
              The `Sweep` with its `runs` populated, or `None` when the response
              contains no project or no sweep.
          """
          from wandb.apis._generated import GET_SWEEP_GQL, GET_SWEEP_LEGACY_GQL

          if not order:
              order = "+created_at"

          variables = {"entity": entity, "project": project, "name": sid, **kwargs}

          if query is None:
              query = gql(GET_SWEEP_GQL)

          try:
              data = client.execute(query, variable_values=variables)
          except Exception:
              # Don't handle exception, rely on legacy query
              # TODO(gst): Implement updated introspection workaround
              query = gql(GET_SWEEP_LEGACY_GQL)
              data = client.execute(query, variable_values=variables)

          # FIXME: looks like this method allows passing arbitrary GQL queries, so for now
          # we'll have to skip trying to validate the result with a generated pydantic model.
          if not (
              data
              and (proj_dict := data.get("project"))
              and (sweep_dict := proj_dict.get("sweep"))
          ):
              return None

          # Non-empty attrs are passed, so the constructor does not re-fetch.
          sweep = cls(client, entity, project, sid, attrs=sweep_dict)
          sweep.runs = public.Runs(
              client,
              entity,
              project,
              order=order,
              per_page=10,
              filters={"$and": [{"sweep": sweep.id}]},
          )
          return sweep

      def _make_sweep_agent(self, attrs: Mapping[str, Any]) -> Agent:
          """Construct `Agent` from API payload.

          Raises:
              Error: If `Agent` rejects the payload with a `ValueError`.
          """
          try:
              return Agent(
                  self.client,
                  attrs=attrs,
                  entity=self.entity,
                  project=self.project,
                  sweep_id=self.id,
              )
          except ValueError as e:
              raise Error(
                  "Sweep agent data from the W&B API was incomplete or invalid.",
                  context={"details": str(e)},
              ) from e

      def agent(self, agent_id: str) -> Agent:
          """Query an agent by ID for this sweep.

          Args:
              agent_id: The ID of the agent to look up.

          Raises:
              Error: If the response has no project, sweep, or agent, or the
                  agent data is incomplete.
          """
          from wandb.apis._generated import GET_SWEEP_AGENT_GQL

          variables = {
              "agentID": agent_id,
              "sweep": self.id,
              "entity": self.entity,
              "project": self.project,
          }
          data = self.client.execute(gql(GET_SWEEP_AGENT_GQL), variable_values=variables)

          # Walk the payload defensively so a missing project/sweep/agent is
          # reported through the same `Error` as incomplete agent data rather
          # than as a bare TypeError/KeyError.
          project_data = (data or {}).get("project") or {}
          sweep_data = project_data.get("sweep") or {}
          return self._make_sweep_agent(sweep_data.get("agent") or {})

      def agents(self) -> list[Agent]:
          """Query the list of all agents for this sweep.

          Returns an empty list when the response has no project or sweep.
          """
          from wandb.apis._generated import GET_SWEEP_AGENTS_GQL, GetSweepAgents

          variables = {
              "sweep": self.id,
              "entity": self.entity,
              "project": self.project,
          }
          data = self.client.execute(gql(GET_SWEEP_AGENTS_GQL), variable_values=variables)
          parsed = GetSweepAgents.model_validate(data)
          if not parsed.project or not parsed.project.sweep:
              return []
          return [
              self._make_sweep_agent(edge.node.model_dump(by_alias=True))
              for edge in parsed.project.sweep.agents.edges
          ]

      def to_html(self, height: int = 420, hidden: bool = False) -> str:
          """Generate HTML containing an iframe displaying this sweep.

          Args:
              height: Height of the iframe in pixels.
              hidden: When true, the iframe starts hidden and is preceded by a
                  toggle button.
          """
          url = self.url + "?jupyter=true"
          style = f"border:none;width:100%;height:{height}px;"
          prefix = ""
          if hidden:
              style += "display:none;"
              prefix = ipython.toggle_button("sweep")
          return prefix + f"<iframe src={url!r} style={style!r}></iframe>"

      def _repr_html_(self) -> str:
          return self.to_html()

      def __repr__(self) -> str:
          pathstr = "/".join(self.path)
          state = self._attrs.get("state", "Unknown State")
          return f"<Sweep {pathstr} ({state})>"
  class Agent(Attrs):
      """A sweep agent, as returned by `Sweep.agent()` or `Sweep.agents()`."""

      def __init__(
          self,
          client: RetryingClient,
          attrs: Mapping[str, Any],
          entity: str,
          project: str,
          sweep_id: str,
      ) -> None:
          """Wrap an agent payload and validate that it is usable.

          Args:
              client: The API client used to query W&B.
              attrs: The agent attributes from the API payload.
              entity: The entity which owns the sweep.
              project: The project which contains the sweep.
              sweep_id: The ID of the sweep the agent belongs to.

          Raises:
              ValueError: If entity, project, or sweep_id is `None`, or the
                  payload has neither a usable name nor a usable id.
          """
          super().__init__(dict(attrs or {}))
          self._client = client
          self._entity = entity
          self._project = project
          self._sweep_id = sweep_id

          # Checked in this order so the first missing field is the one reported.
          required = (
              ("entity", self._entity),
              ("project", self._project),
              ("sweep_id", self._sweep_id),
          )
          for label, value in required:
              if value is None:
                  raise ValueError(
                      f"Agent requires {label}. "
                      "Use an Agent returned from sweep.agent(...) or sweep.agents()."
                  )

          agent_name = self._attrs.get("name")
          agent_id = self._attrs.get("id")
          key = agent_name or agent_id
          if not key:
              # Both are falsy: say which one is absent, or that neither is usable.
              if agent_name is None:
                  raise ValueError("Agent is missing name.")
              if agent_id is None:
                  raise ValueError("Agent is missing id.")
              raise ValueError("Agent is missing a usable name or id.")

          # The name is preferred; the id is the fallback.
          self._agent_key: str = key

      def runs(
          self,
          per_page: int = 50,
      ) -> AgentRuns:
          """Return a paginated collection of runs executed by this agent.

          Args:
              per_page: The number of runs to fetch per request.
          """
          from wandb.apis.public.runs import AgentRuns

          return AgentRuns(
              self._client,
              entity=self._entity,
              project=self._project,
              sweep_id=self._sweep_id,
              agent_key=self._agent_key,
              # A missing or falsy `totalRuns` counts as zero.
              total_runs=int(self._attrs.get("totalRuns") or 0),
              order="+created_at",
              per_page=per_page,
          )

      def __repr__(self) -> str:
          ident = self._attrs.get("id", "Unknown")
          status = self._attrs.get("state", "Unknown State")
          return f"<Agent {ident} ({status})>"