| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- """A simple GraphQL client for sending queries and mutations.
- Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
- The only substantial change is to reuse a requests.Session object.
- """
- from __future__ import annotations
- from typing import Any, Callable
- from wandb_gql.transport.http import HTTPTransport
- from wandb_graphql.execution import ExecutionResult
- from wandb_graphql.language import ast
- from wandb_graphql.language.printer import print_ast
- from wandb._analytics import tracked_func
class GraphQLSession(HTTPTransport):
    """GraphQL HTTP transport that sends every request through one ``requests.Session``."""

    def __init__(
        self,
        url: str,
        auth: tuple[str, str] | Callable | None = None,
        use_json: bool = False,
        timeout: int | float | None = None,
        proxies: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up a session for sending GraphQL queries and mutations.

        Args:
            url (str): The GraphQL URL.
            auth (tuple or callable): Auth tuple or callable for Basic/Digest/Custom
                HTTP auth; assigned to ``session.auth``.
            use_json (bool): Send the request body as JSON instead of
                form-urlencoded.
            timeout (int, float): Default timeout for requests, used whenever
                ``execute`` is called without an explicit timeout (Default: None).
            proxies (dict): Optional proxy mapping merged into the session's
                ``proxies`` (Default: None).
            **kwargs: Forwarded unchanged to ``HTTPTransport.__init__``.
        """
        # Imported here rather than at module level; presumably to defer the
        # cost of importing requests until a session is built -- TODO confirm.
        import requests

        super().__init__(url, **kwargs)
        self.session = requests.Session()
        if proxies:
            self.session.proxies.update(proxies)
        self.session.auth = auth
        self.default_timeout = timeout
        self.use_json = use_json

    def execute(
        self,
        document: ast.Node,
        variable_values: dict[str, Any] | None = None,
        timeout: int | float | None = None,
    ) -> ExecutionResult:
        """POST a GraphQL document to ``self.url`` and wrap the reply.

        Args:
            document: Parsed GraphQL document; serialized with ``print_ast``.
            variable_values: Variables for the operation (Default: empty dict).
            timeout: Per-call timeout. ``None`` falls back to the session's
                ``default_timeout``; any other value is passed through as-is.

        Returns:
            An ``ExecutionResult`` built from the response's ``data`` and
            ``errors`` fields.

        Raises:
            requests.HTTPError: If the response has an error status code
                (via ``raise_for_status``).
            RuntimeError: If the decoded body is not a JSON object, or has
                neither ``data`` nor ``errors``.
        """
        payload = {
            "query": print_ast(document),
            "variables": variable_values or {},
        }
        # The body goes either as JSON or as form data, depending on configuration.
        data_key = "json" if self.use_json else "data"

        # Copy so that per-request additions never leak into self.headers.
        headers = self.headers.copy() if self.headers else {}
        # If we're tracking a calling python function, include it in the headers
        if func_info := tracked_func():
            headers.update(func_info.to_headers())

        post_args = {
            "headers": headers or None,
            "cookies": self.cookies,
            # Only a missing timeout falls back to the default; an explicit
            # value (even a falsy one) is honoured rather than silently replaced.
            "timeout": self.default_timeout if timeout is None else timeout,
            data_key: payload,
        }
        response = self.session.post(self.url, **post_args)
        response.raise_for_status()

        result = response.json()
        # A valid JSON body that is not an object (list, string, number) would
        # otherwise fail with an opaque AttributeError on ``.get``.
        if not isinstance(result, dict):
            raise RuntimeError(f"Received non-compatible response: {result}")
        data, errors = result.get("data"), result.get("errors")
        if data is None and errors is None:
            raise RuntimeError(f"Received non-compatible response: {result}")
        return ExecutionResult(data=data, errors=errors)
|