automations.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. """W&B Public API for Automation objects."""
  2. from __future__ import annotations
  3. from collections.abc import Iterator, Mapping
  4. from itertools import chain
  5. from typing import TYPE_CHECKING, Any
  6. from pydantic import ValidationError
  7. from typing_extensions import override
  8. from wandb.apis.paginator import RelayPaginator
  9. if TYPE_CHECKING:
  10. from wandb_graphql.language.ast import Document
  11. from wandb._pydantic import Connection
  12. from wandb.apis.public.api import RetryingClient
  13. from wandb.automations import Automation
  14. from wandb.automations._generated import ProjectTriggersFields
  class Automations(RelayPaginator["ProjectTriggersFields", "Automation"]):
      """A lazy iterator of `Automation` objects.

      Each page of the GraphQL response is validated as a connection of
      project nodes; the triggers of every project node are flattened into a
      single stream of `Automation` objects.

      <!-- lazydoc-ignore-class: internal -->
      """

      QUERY: Document  # Must be set per-instance
      # Most recently fetched and validated page (presumably None before the
      # first fetch -- TODO confirm against RelayPaginator).
      last_response: Connection[ProjectTriggersFields] | None

      def __init__(
          self,
          client: RetryingClient,
          variables: Mapping[str, Any],
          per_page: int = 50,
          *,
          _query: Document,  # internal use only, but required
      ):
          """Create a paginator over the automations returned by `_query`.

          Args:
              client: Client used to execute the GraphQL query.
              variables: GraphQL variables forwarded to the base paginator.
              per_page: Page size forwarded to the base paginator.
              _query: GraphQL document executed for each page. Internal use
                  only, but required.
          """
          # Set before calling the base initializer, in case it already needs
          # the query (NOTE(review): confirm against RelayPaginator.__init__).
          self.QUERY = _query
          super().__init__(client, variables=variables, per_page=per_page)

      @override
      def _update_response(self) -> None:
          """Fetch and validate the response data for the current page.

          On success the parsed page is stored in `self.last_response`.

          Raises:
              ValueError: If the response does not contain a valid
                  `scope.projects` connection (missing keys, a null `scope`,
                  or data that fails validation).
          """
          # Imported at call time: at module level these names are only
          # imported under TYPE_CHECKING.
          from wandb._pydantic import Connection
          from wandb.automations._generated import ProjectTriggersFields

          data = self.client.execute(self.QUERY, variable_values=self.variables)
          try:
              conn_data = data["scope"]["projects"]
              conn = Connection[ProjectTriggersFields].model_validate(conn_data)
          except (LookupError, AttributeError, TypeError, ValidationError) as e:
              # TypeError covers a null `data` or `scope` (e.g. `None["projects"]`),
              # which would otherwise escape without the ValueError wrapper.
              raise ValueError("Unexpected response data") from e
          else:
              self.last_response = conn

      @override
      def _convert(self, node: ProjectTriggersFields) -> Iterator[Automation]:
          """Lazily convert one project node's triggers into `Automation` objects.

          Args:
              node: A single project node from the validated connection.

          Returns:
              A generator that validates each trigger of `node` on demand.
          """
          # Imported at call time: at module level `Automation` is only
          # imported under TYPE_CHECKING.
          from wandb.automations import Automation

          return (Automation.model_validate(trigger) for trigger in node.triggers)

      @override
      def convert_objects(self) -> Iterator[Automation]:
          """Return one flat iterator of `Automation` objects.

          `_convert` produces an iterable per project node; the base
          implementation presumably yields one such iterable per node
          (TODO confirm), so the results are flattened here.
          """
          return chain.from_iterable(super().convert_objects())