paginator.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from collections.abc import Iterable, Iterator, Mapping, Sized
  4. from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload
  5. import wandb
  6. from wandb._strutils import nameof
  7. if TYPE_CHECKING:
  8. from wandb_graphql.language.ast import Document
  9. from wandb._pydantic import Connection
  10. from wandb.apis.public.api import RetryingClient
  11. _WandbT = TypeVar("_WandbT")
  12. """Generic type variable for a W&B object."""
  13. _NodeT = TypeVar("_NodeT")
  14. """Generic type variable for a parsed GraphQL relay node."""
  15. class Paginator(Iterator[_WandbT], ABC):
  16. """An iterator for paginated objects from GraphQL requests."""
  17. QUERY: Document | ClassVar[Document | None]
  18. def __init__(
  19. self,
  20. client: RetryingClient,
  21. variables: Mapping[str, Any],
  22. per_page: int = 50, # We don't allow unbounded paging
  23. ):
  24. self.client = client
  25. # shallow copy partly guards against mutating the original input
  26. self.variables: dict[str, Any] = dict(variables)
  27. self.per_page: int = per_page
  28. self.objects: list[_WandbT] = []
  29. self.index: int = -1
  30. self.last_response: object | None = None
  31. def __iter__(self) -> Iterator[_WandbT]:
  32. self.index = -1
  33. return self
  34. @property
  35. @abstractmethod
  36. def more(self) -> bool:
  37. """Whether there are more pages to be fetched."""
  38. raise NotImplementedError
  39. @property
  40. @abstractmethod
  41. def cursor(self) -> str | None:
  42. """The start cursor to use for the next fetched page."""
  43. raise NotImplementedError
  44. @abstractmethod
  45. def convert_objects(self) -> Iterable[_WandbT]:
  46. """Convert the last fetched response data into the iterated objects."""
  47. raise NotImplementedError
  48. def update_variables(self) -> None:
  49. """Update the query variables for the next page fetch."""
  50. self.variables.update({"perPage": self.per_page, "cursor": self.cursor})
  51. def _update_response(self) -> None:
  52. """Fetch and store the response data for the next page."""
  53. self.last_response = self.client.execute(
  54. self.QUERY, variable_values=self.variables
  55. )
  56. def _load_page(self) -> bool:
  57. """Fetch the next page, if any, returning True and storing the response if there was one."""
  58. if not self.more:
  59. return False
  60. self.update_variables()
  61. self._update_response()
  62. self.objects.extend(self.convert_objects())
  63. return True
  64. @overload
  65. def __getitem__(self, index: int) -> _WandbT: ...
  66. @overload
  67. def __getitem__(self, index: slice) -> list[_WandbT]: ...
  68. def __getitem__(self, index: int | slice) -> _WandbT | list[_WandbT]:
  69. loaded = True
  70. stop = index.stop if isinstance(index, slice) else index
  71. while loaded and stop > len(self.objects) - 1:
  72. loaded = self._load_page()
  73. return self.objects[index]
  74. def __next__(self) -> _WandbT:
  75. self.index += 1
  76. if len(self.objects) <= self.index:
  77. if not self._load_page():
  78. raise StopIteration
  79. if len(self.objects) <= self.index:
  80. raise StopIteration
  81. return self.objects[self.index]
  82. next = __next__
  83. class SizedPaginator(Paginator[_WandbT], Sized, ABC):
  84. """A Paginator for objects with a known total count."""
  85. @property
  86. def length(self) -> int | None:
  87. wandb.termwarn(
  88. (
  89. "`.length` is deprecated and will be removed in a future version. "
  90. "Use `len(...)` instead."
  91. ),
  92. repeat=False,
  93. )
  94. return len(self)
  95. def __len__(self) -> int:
  96. if self._length is None:
  97. self._load_page()
  98. if self._length is None:
  99. raise ValueError("Object doesn't provide length")
  100. return self._length
  101. @property
  102. @abstractmethod
  103. def _length(self) -> int | None:
  104. raise NotImplementedError
  105. class RelayPaginator(Paginator[_WandbT], Generic[_NodeT, _WandbT], ABC):
  106. """A Paginator for GQL relay-style nodes parsed via Pydantic.
  107. <!-- lazydoc-ignore-class: internal -->
  108. """
  109. last_response: Connection[_NodeT] | None
  110. @property
  111. def more(self) -> bool:
  112. return (conn := self.last_response) is None or conn.has_next
  113. @property
  114. def cursor(self) -> str | None:
  115. return conn.next_cursor if (conn := self.last_response) else None
  116. @abstractmethod
  117. def _convert(self, node: _NodeT) -> _WandbT | Any:
  118. """Convert a parsed GraphQL node into the iterated object.
  119. If a falsey value is returned, it will be skipped during iteration.
  120. """
  121. raise NotImplementedError
  122. def convert_objects(self) -> Iterable[_WandbT]:
  123. # Default implementation. Subclasses can override this if if more complex
  124. # logic is needed, but ideally most shouldn't need to.
  125. if conn := self.last_response:
  126. yield from filter(None, map(self._convert, conn.nodes()))
  127. class SizedRelayPaginator(RelayPaginator[_NodeT, _WandbT], Sized, ABC):
  128. """A Paginator for GQL nodes parsed via Pydantic, with a known total count.
  129. <!-- lazydoc-ignore-class: internal -->
  130. """
  131. last_response: Connection[_NodeT] | None
  132. def __len__(self) -> int:
  133. """Returns the total number of objects to expect."""
  134. # If the first page hasn't been fetched yet, do that first
  135. if self.last_response is None:
  136. self._load_page()
  137. if (conn := self.last_response) and (total := conn.total_count) is not None:
  138. return total
  139. raise NotImplementedError(f"{nameof(type(self))!r} doesn't provide length")