gql_request.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """A simple GraphQL client for sending queries and mutations.
  2. Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
  3. The only substantial change is to reuse a requests.Session object.
  4. """
  5. from __future__ import annotations
  6. from typing import Any, Callable
  7. from wandb_gql.transport.http import HTTPTransport
  8. from wandb_graphql.execution import ExecutionResult
  9. from wandb_graphql.language import ast
  10. from wandb_graphql.language.printer import print_ast
  11. from wandb._analytics import tracked_func
  12. class GraphQLSession(HTTPTransport):
  13. def __init__(
  14. self,
  15. url: str,
  16. auth: tuple[str, str] | Callable | None = None,
  17. use_json: bool = False,
  18. timeout: int | float | None = None,
  19. proxies: dict[str, str] | None = None,
  20. **kwargs: Any,
  21. ) -> None:
  22. """Setup a session for sending GraphQL queries and mutations.
  23. Args:
  24. url (str): The GraphQL URL
  25. auth (tuple or callable): Auth tuple or callable for Basic/Digest/Custom HTTP Auth
  26. use_json (bool): Send request body as JSON instead of form-urlencoded
  27. timeout (int, float): Specifies a default timeout for requests (Default: None)
  28. """
  29. import requests
  30. super().__init__(url, **kwargs)
  31. self.session = requests.Session()
  32. if proxies:
  33. self.session.proxies.update(proxies)
  34. self.session.auth = auth
  35. self.default_timeout = timeout
  36. self.use_json = use_json
  37. def execute(
  38. self,
  39. document: ast.Node,
  40. variable_values: dict[str, Any] | None = None,
  41. timeout: int | float | None = None,
  42. ) -> ExecutionResult:
  43. query_str = print_ast(document)
  44. payload = {"query": query_str, "variables": variable_values or {}}
  45. data_key = "json" if self.use_json else "data"
  46. headers = self.headers.copy() if self.headers else {}
  47. # If we're tracking a calling python function, include it in the headers
  48. if func_info := tracked_func():
  49. headers.update(func_info.to_headers())
  50. post_args = {
  51. "headers": headers or None,
  52. "cookies": self.cookies,
  53. "timeout": timeout or self.default_timeout,
  54. data_key: payload,
  55. }
  56. request = self.session.post(self.url, **post_args)
  57. request.raise_for_status()
  58. result = request.json()
  59. data, errors = result.get("data"), result.get("errors")
  60. if data is None and errors is None:
  61. raise RuntimeError(f"Received non-compatible response: {result}")
  62. return ExecutionResult(data=data, errors=errors)