| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830 |
- """A kernel gateway client."""
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- from __future__ import annotations
- import asyncio
- import json
- import logging
- import os
- import typing as ty
- from abc import ABC, ABCMeta, abstractmethod
- from datetime import datetime, timezone
- from email.utils import parsedate_to_datetime
- from http.cookies import SimpleCookie
- from socket import gaierror
- from jupyter_events import EventLogger
- from tornado import web
- from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
- from traitlets import (
- Bool,
- Float,
- Instance,
- Int,
- TraitError,
- Type,
- Unicode,
- default,
- observe,
- validate,
- )
- from traitlets.config import LoggingConfigurable, SingletonConfigurable
- from jupyter_server import DEFAULT_EVENTS_SCHEMA_PATH, JUPYTER_SERVER_EVENTS_URI
# Values stored under STATUS_KEY in emitted gateway-client events.
ERROR_STATUS: ty.Final = "error"
SUCCESS_STATUS: ty.Final = "success"
# Keys of the event payload passed to GatewayClient.emit().
STATUS_KEY: ty.Final = "status"
STATUS_CODE_KEY: ty.Final = "status_code"
MESSAGE_KEY: ty.Final = "msg"
- if ty.TYPE_CHECKING:
- from http.cookies import Morsel
class GatewayTokenRenewerMeta(ABCMeta, type(LoggingConfigurable)):  # type: ignore[misc]
    """Metaclass deriving from both ``ABCMeta`` and the metaclass of ``LoggingConfigurable``.

    A class that inherits from ``ABC`` and from a traitlets configurable needs a metaclass
    compatible with both bases; this one provides it so abstract methods are enforced on
    configurable classes.
    """
class GatewayTokenRenewerBase(  # type:ignore[misc]
    ABC, LoggingConfigurable, metaclass=GatewayTokenRenewerMeta
):
    """
    Abstract interface for renewing the token this server presents to a Gateway server.

    Concrete renewers implement :meth:`get_token`. A renewer that needs additional settings
    may declare its own configurable traits, or read implementation-specific environment
    variables.
    """

    @abstractmethod
    def get_token(
        self,
        auth_header_key: str,
        auth_scheme: ty.Union[str, None],
        auth_token: str,
        **kwargs: ty.Any,
    ) -> str:
        """
        Return the token to use against the Gateway server, renewing it if appropriate.

        Called with the current authorization header key, scheme and token.
        """
        ...
class NoOpTokenRenewer(GatewayTokenRenewerBase):  # type:ignore[misc]
    """Default value of the ``GatewayClient.gateway_token_renewer`` trait.

    Performs no renewal at all: the token it is given is the token it returns.
    """

    def get_token(
        self, auth_header_key: str, auth_scheme: ty.Union[str, None], auth_token: str, **kwargs: ty.Any
    ) -> str:
        """Hand back ``auth_token`` untouched; every other argument is ignored."""
        return auth_token
class GatewayClient(SingletonConfigurable):
    """This class manages the configuration. It's its own singleton class so
    that we can share these values across all objects. It also contains some
    helper methods to build request arguments out of the various config
    options.
    """

    event_schema_id = JUPYTER_SERVER_EVENTS_URI + "/gateway_client/v1"
    event_logger = Instance(EventLogger).tag(config=True)

    @default("event_logger")
    def _default_event_logger(self):
        """Reuse the parent's event logger if it has one, otherwise create and register one."""
        if self.parent and hasattr(self.parent, "event_logger"):
            # Event logger is attached from serverapp.
            return self.parent.event_logger
        # If parent does not have an event logger, create one.
        logger = EventLogger()
        schema_path = DEFAULT_EVENTS_SCHEMA_PATH / "gateway_client" / "v1.yaml"
        logger.register_event_schema(schema_path)
        self.log.info("Event is registered in GatewayClient.")
        return logger

    def emit(self, data):
        """Emit event using the core event schema from Jupyter Server's Gateway Client."""
        self.event_logger.emit(schema_id=self.event_schema_id, data=data)

    @staticmethod
    def _env_flag(env_name: str, default: bool) -> bool:
        """Read a boolean flag from the environment variable ``env_name``.

        The flag is False only when the variable's value is "no" or "false", compared
        case-insensitively and ignoring surrounding whitespace; any other value is True.
        When the variable is unset, ``default`` applies.
        """
        value = os.environ.get(env_name, str(default))
        return value.strip().lower() not in ("no", "false")

    url = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The url of the Kernel or Enterprise Gateway server where
        kernel specifications are defined and kernel management takes place.
        If defined, this Notebook server acts as a proxy for all kernel
        management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var)
        """,
    )

    url_env = "JUPYTER_GATEWAY_URL"

    @default("url")
    def _url_default(self):
        return os.environ.get(self.url_env)

    @validate("url")
    def _url_validate(self, proposal):
        value = proposal["value"]
        # Ensure value, if present, starts with 'http'
        if value is not None and len(value) > 0 and not str(value).lower().startswith("http"):
            message = "GatewayClient url must start with 'http': '%r'" % value
            self.emit(data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 400, MESSAGE_KEY: message})
            raise TraitError(message)
        return value

    ws_url = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value
        will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var)
        """,
    )

    ws_url_env = "JUPYTER_GATEWAY_WS_URL"

    @default("ws_url")
    def _ws_url_default(self):
        default_value = os.environ.get(self.ws_url_env)
        if self.url is not None and default_value is None and self.gateway_enabled:
            # Swap only the scheme (http -> ws, https -> wss). Host and path keep their
            # original text: paths are case-sensitive, and "http" may legitimately appear
            # in the host or path.
            scheme, sep, rest = self.url.partition(":")
            default_value = "ws" + scheme.lower()[len("http") :] + sep + rest
        return default_value

    @validate("ws_url")
    def _ws_url_validate(self, proposal):
        value = proposal["value"]
        # Ensure value, if present, starts with 'ws'
        if value is not None and len(value) > 0 and not str(value).lower().startswith("ws"):
            message = "GatewayClient ws_url must start with 'ws': '%r'" % value
            self.emit(data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 400, MESSAGE_KEY: message})
            raise TraitError(message)
        return value

    kernels_endpoint_default_value = "/api/kernels"
    kernels_endpoint_env = "JUPYTER_GATEWAY_KERNELS_ENDPOINT"
    kernels_endpoint = Unicode(
        default_value=kernels_endpoint_default_value,
        config=True,
        help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""",
    )

    @default("kernels_endpoint")
    def _kernels_endpoint_default(self):
        return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value)

    kernelspecs_endpoint_default_value = "/api/kernelspecs"
    kernelspecs_endpoint_env = "JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT"
    kernelspecs_endpoint = Unicode(
        default_value=kernelspecs_endpoint_default_value,
        config=True,
        help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""",
    )

    @default("kernelspecs_endpoint")
    def _kernelspecs_endpoint_default(self):
        return os.environ.get(
            self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value
        )

    kernelspecs_resource_endpoint_default_value = "/kernelspecs"
    kernelspecs_resource_endpoint_env = "JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT"
    kernelspecs_resource_endpoint = Unicode(
        default_value=kernelspecs_resource_endpoint_default_value,
        config=True,
        help="""The gateway endpoint for accessing kernelspecs resources
        (JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""",
    )

    @default("kernelspecs_resource_endpoint")
    def _kernelspecs_resource_endpoint_default(self):
        return os.environ.get(
            self.kernelspecs_resource_endpoint_env,
            self.kernelspecs_resource_endpoint_default_value,
        )

    connect_timeout_default_value = 40.0
    connect_timeout_env = "JUPYTER_GATEWAY_CONNECT_TIMEOUT"
    connect_timeout = Float(
        default_value=connect_timeout_default_value,
        config=True,
        help="""The time allowed for HTTP connection establishment with the Gateway server.
        (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""",
    )

    @default("connect_timeout")
    def _connect_timeout_default(self):
        return float(os.environ.get(self.connect_timeout_env, self.connect_timeout_default_value))

    request_timeout_default_value = 42.0
    request_timeout_env = "JUPYTER_GATEWAY_REQUEST_TIMEOUT"
    request_timeout = Float(
        default_value=request_timeout_default_value,
        config=True,
        help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""",
    )

    @default("request_timeout")
    def _request_timeout_default(self):
        return float(os.environ.get(self.request_timeout_env, self.request_timeout_default_value))

    client_key = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The filename for client SSL key, if any. (JUPYTER_GATEWAY_CLIENT_KEY env var)
        """,
    )
    client_key_env = "JUPYTER_GATEWAY_CLIENT_KEY"

    @default("client_key")
    def _client_key_default(self):
        return os.environ.get(self.client_key_env)

    client_cert = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The filename for client SSL certificate, if any. (JUPYTER_GATEWAY_CLIENT_CERT env var)
        """,
    )
    client_cert_env = "JUPYTER_GATEWAY_CLIENT_CERT"

    @default("client_cert")
    def _client_cert_default(self):
        return os.environ.get(self.client_cert_env)

    ca_certs = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The filename of CA certificates or None to use defaults. (JUPYTER_GATEWAY_CA_CERTS env var)
        """,
    )
    ca_certs_env = "JUPYTER_GATEWAY_CA_CERTS"

    @default("ca_certs")
    def _ca_certs_default(self):
        return os.environ.get(self.ca_certs_env)

    http_user = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var)
        """,
    )
    http_user_env = "JUPYTER_GATEWAY_HTTP_USER"

    @default("http_user")
    def _http_user_default(self):
        return os.environ.get(self.http_user_env)

    http_pwd = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The password for HTTP authentication. (JUPYTER_GATEWAY_HTTP_PWD env var)
        """,
    )
    http_pwd_env = "JUPYTER_GATEWAY_HTTP_PWD"  # noqa: S105

    @default("http_pwd")
    def _http_pwd_default(self):
        return os.environ.get(self.http_pwd_env)

    headers_default_value = "{}"
    headers_env = "JUPYTER_GATEWAY_HEADERS"
    headers = Unicode(
        default_value=headers_default_value,
        allow_none=True,
        config=True,
        help="""Additional HTTP headers to pass on the request. This value will be converted to a dict.
          (JUPYTER_GATEWAY_HEADERS env var)
        """,
    )

    @default("headers")
    def _headers_default(self):
        return os.environ.get(self.headers_env, self.headers_default_value)

    auth_header_key_default_value = "Authorization"
    auth_header_key = Unicode(
        config=True,
        help="""The authorization header's key name (typically 'Authorization') used in the HTTP headers. The
        header will be formatted as::

        {'{auth_header_key}': '{auth_scheme} {auth_token}'}

        If the authorization header key takes a single value, `auth_scheme` should be set to None and
        'auth_token' should be configured to use the appropriate value.

        (JUPYTER_GATEWAY_AUTH_HEADER_KEY env var)""",
    )
    auth_header_key_env = "JUPYTER_GATEWAY_AUTH_HEADER_KEY"

    @default("auth_header_key")
    def _auth_header_key_default(self):
        return os.environ.get(self.auth_header_key_env, self.auth_header_key_default_value)

    auth_token_default_value = ""
    auth_token = Unicode(
        default_value=None,
        allow_none=True,
        config=True,
        help="""The authorization token used in the HTTP headers. The header will be formatted as::

        {'{auth_header_key}': '{auth_scheme} {auth_token}'}

        (JUPYTER_GATEWAY_AUTH_TOKEN env var)""",
    )
    auth_token_env = "JUPYTER_GATEWAY_AUTH_TOKEN"  # noqa: S105

    @default("auth_token")
    def _auth_token_default(self):
        return os.environ.get(self.auth_token_env, self.auth_token_default_value)

    auth_scheme_default_value = "token"  # This value is purely for backwards compatibility
    auth_scheme = Unicode(
        allow_none=True,
        config=True,
        help="""The auth scheme, added as a prefix to the authorization token used in the HTTP headers.
        (JUPYTER_GATEWAY_AUTH_SCHEME env var)""",
    )
    auth_scheme_env = "JUPYTER_GATEWAY_AUTH_SCHEME"

    @default("auth_scheme")
    def _auth_scheme_default(self):
        return os.environ.get(self.auth_scheme_env, self.auth_scheme_default_value)

    validate_cert_default_value = True
    validate_cert_env = "JUPYTER_GATEWAY_VALIDATE_CERT"
    validate_cert = Bool(
        default_value=validate_cert_default_value,
        config=True,
        help="""For HTTPS requests, determines if server's certificate should be validated or not.
        (JUPYTER_GATEWAY_VALIDATE_CERT env var)""",
    )

    @default("validate_cert")
    def _validate_cert_default(self):
        # Case-insensitive so that e.g. "False" or "NO" actually disables validation.
        return self._env_flag(self.validate_cert_env, self.validate_cert_default_value)

    allowed_envs_default_value = ""
    allowed_envs_env = "JUPYTER_GATEWAY_ALLOWED_ENVS"
    allowed_envs = Unicode(
        default_value=allowed_envs_default_value,
        config=True,
        help="""A comma-separated list of environment variable names that will be included, along with
        their values, in the kernel startup request.  The corresponding `client_envs` configuration
        value must also be set on the Gateway server - since that configuration value indicates which
        environmental values to make available to the kernel. (JUPYTER_GATEWAY_ALLOWED_ENVS env var)""",
    )

    @default("allowed_envs")
    def _allowed_envs_default(self):
        # Fall back to the deprecated JUPYTER_GATEWAY_ENV_WHITELIST variable.
        return os.environ.get(
            self.allowed_envs_env,
            os.environ.get("JUPYTER_GATEWAY_ENV_WHITELIST", self.allowed_envs_default_value),
        )

    env_whitelist = Unicode(
        default_value=allowed_envs_default_value,
        config=True,
        help="""Deprecated, use `GatewayClient.allowed_envs`""",
    )

    gateway_retry_interval_default_value = 1.0
    gateway_retry_interval_env = "JUPYTER_GATEWAY_RETRY_INTERVAL"
    gateway_retry_interval = Float(
        default_value=gateway_retry_interval_default_value,
        config=True,
        help="""The time allowed for HTTP reconnection with the Gateway server for the first time.
        Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries
        but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX.
        (JUPYTER_GATEWAY_RETRY_INTERVAL env var)""",
    )

    @default("gateway_retry_interval")
    def _gateway_retry_interval_default(self):
        return float(
            os.environ.get(
                self.gateway_retry_interval_env,
                self.gateway_retry_interval_default_value,
            )
        )

    gateway_retry_interval_max_default_value = 30.0
    gateway_retry_interval_max_env = "JUPYTER_GATEWAY_RETRY_INTERVAL_MAX"
    gateway_retry_interval_max = Float(
        default_value=gateway_retry_interval_max_default_value,
        config=True,
        help="""The maximum time allowed for HTTP reconnection retry with the Gateway server.
        (JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""",
    )

    @default("gateway_retry_interval_max")
    def _gateway_retry_interval_max_default(self):
        return float(
            os.environ.get(
                self.gateway_retry_interval_max_env,
                self.gateway_retry_interval_max_default_value,
            )
        )

    gateway_retry_max_default_value = 5
    gateway_retry_max_env = "JUPYTER_GATEWAY_RETRY_MAX"
    gateway_retry_max = Int(
        default_value=gateway_retry_max_default_value,
        config=True,
        help="""The maximum retries allowed for HTTP reconnection with the Gateway server.
        (JUPYTER_GATEWAY_RETRY_MAX env var)""",
    )

    @default("gateway_retry_max")
    def _gateway_retry_max_default(self):
        return int(os.environ.get(self.gateway_retry_max_env, self.gateway_retry_max_default_value))

    gateway_token_renewer_class_default_value = (
        "jupyter_server.gateway.gateway_client.NoOpTokenRenewer"  # noqa: S105
    )
    gateway_token_renewer_class_env = "JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS"  # noqa: S105
    gateway_token_renewer_class = Type(
        klass=GatewayTokenRenewerBase,
        config=True,
        help="""The class to use for Gateway token renewal. (JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS env var)""",
    )

    @default("gateway_token_renewer_class")
    def _gateway_token_renewer_class_default(self):
        return os.environ.get(
            self.gateway_token_renewer_class_env, self.gateway_token_renewer_class_default_value
        )

    launch_timeout_pad_default_value = 2.0
    launch_timeout_pad_env = "JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD"
    launch_timeout_pad = Float(
        default_value=launch_timeout_pad_default_value,
        config=True,
        help="""Timeout pad to be ensured between KERNEL_LAUNCH_TIMEOUT and request_timeout
        such that request_timeout >= KERNEL_LAUNCH_TIMEOUT + launch_timeout_pad.
        (JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD env var)""",
    )

    @default("launch_timeout_pad")
    def _launch_timeout_pad_default(self):
        return float(
            os.environ.get(
                self.launch_timeout_pad_env,
                self.launch_timeout_pad_default_value,
            )
        )

    accept_cookies_value = False
    accept_cookies_env = "JUPYTER_GATEWAY_ACCEPT_COOKIES"
    accept_cookies = Bool(
        default_value=accept_cookies_value,
        config=True,
        help="""Accept and manage cookies sent by the service side. This is often useful
        for load balancers to decide which backend node to use.
        (JUPYTER_GATEWAY_ACCEPT_COOKIES env var)""",
    )

    @default("accept_cookies")
    def _accept_cookies_default(self):
        # Case-insensitive so that e.g. "True"/"False" behave as expected.
        return self._env_flag(self.accept_cookies_env, self.accept_cookies_value)

    _deprecated_traits = {
        "env_whitelist": ("allowed_envs", "2.0"),
    }

    # Method copied from
    # https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
    @observe(*list(_deprecated_traits))
    def _deprecated_trait(self, change):
        """observer for deprecated traits"""
        old_attr = change.name
        new_attr, version = self._deprecated_traits[old_attr]
        new_value = getattr(self, new_attr)
        if new_value != change.new:
            # only warn if different
            # protects backward-compatible config from warnings
            # if they set the same value under both names
            self.log.warning(
                f"{self.__class__.__name__}.{old_attr} is deprecated in jupyter_server "
                f"{version}, use {self.__class__.__name__}.{new_attr} instead"
            )
            setattr(self, new_attr, change.new)

    @property
    def gateway_enabled(self):
        """True when a non-empty Gateway ``url`` is configured."""
        return bool(self.url is not None and len(self.url) > 0)

    # Ensure KERNEL_LAUNCH_TIMEOUT has a default value.
    KERNEL_LAUNCH_TIMEOUT = int(os.environ.get("KERNEL_LAUNCH_TIMEOUT", 40))

    _connection_args: dict[str, ty.Any]  # initialized on first use

    gateway_token_renewer: GatewayTokenRenewerBase

    def __init__(self, **kwargs):
        """Initialize a gateway client."""
        super().__init__(**kwargs)
        self._connection_args = {}  # initialized on first use
        self.gateway_token_renewer = self.gateway_token_renewer_class(parent=self, log=self.log)  # type:ignore[abstract]

        # store of cookies with store time
        self._cookies: dict[str, tuple[Morsel[ty.Any], datetime]] = {}

    def _auth_header_value(self) -> str:
        """Return the value for the authorization header.

        Formatted as ``"{auth_scheme} {auth_token}"``. When ``auth_scheme`` is None or empty
        (the documented setting for single-value authorization headers), only the token is
        used, rather than producing e.g. ``"None <token>"``.
        """
        if self.auth_scheme:
            return f"{self.auth_scheme} {self.auth_token}"
        return f"{self.auth_token}"

    def init_connection_args(self):
        """Initialize arguments used on every request.  Since these are primarily static values,
        we'll perform this operation once.
        """
        # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are in sync, taking the
        #  greater value of the two and taking into account the following relation:
        #  request_timeout = KERNEL_LAUNCH_TIME + padding
        minimum_request_timeout = (
            float(GatewayClient.KERNEL_LAUNCH_TIMEOUT) + self.launch_timeout_pad
        )
        if self.request_timeout < minimum_request_timeout:
            self.request_timeout = minimum_request_timeout
        elif self.request_timeout > minimum_request_timeout:
            GatewayClient.KERNEL_LAUNCH_TIMEOUT = int(
                self.request_timeout - self.launch_timeout_pad
            )
        # Ensure any adjustments are reflected in env.
        os.environ["KERNEL_LAUNCH_TIMEOUT"] = str(GatewayClient.KERNEL_LAUNCH_TIMEOUT)

        # An empty or None `headers` value is treated like "{}" so the authorization header
        # is still managed (load_connection_args only refreshes it when "headers" is present).
        static_headers = json.loads(self.headers) if self.headers else {}
        if self.auth_header_key not in static_headers:
            static_headers[self.auth_header_key] = self._auth_header_value()
        self._connection_args["headers"] = static_headers

        self._connection_args["connect_timeout"] = self.connect_timeout
        self._connection_args["request_timeout"] = self.request_timeout
        self._connection_args["validate_cert"] = self.validate_cert
        if self.client_cert:
            self._connection_args["client_cert"] = self.client_cert
            self._connection_args["client_key"] = self.client_key
        if self.ca_certs:
            self._connection_args["ca_certs"] = self.ca_certs
        if self.http_user:
            self._connection_args["auth_username"] = self.http_user
        if self.http_pwd:
            self._connection_args["auth_password"] = self.http_pwd

    def load_connection_args(self, **kwargs):
        """Merges the static args relative to the connection, with the given keyword arguments.  If static
        args have yet to be initialized, we'll do that here.

        The token renewer is given a chance to refresh ``auth_token`` first; if it raises, the
        error is logged and the current token is kept. Static headers are merged into the
        caller's ``headers`` dict (when it is a dict) and the authorization header is always set
        to the current token. Managed cookies are added when ``accept_cookies`` is enabled.
        """
        if len(self._connection_args) == 0:
            self.init_connection_args()

        # Give token renewal a shot at renewing the token
        prev_auth_token = self.auth_token
        if self.auth_token is not None:
            try:
                self.auth_token = self.gateway_token_renewer.get_token(
                    self.auth_header_key, self.auth_scheme, self.auth_token
                )
            except Exception as ex:
                self.log.error(
                    f"An exception occurred attempting to renew the "
                    f"Gateway authorization token using an instance of class "
                    f"'{self.gateway_token_renewer_class}'.  The request will "
                    f"proceed using the current token value.  Exception was: {ex}"
                )
                self.auth_token = prev_auth_token

        for arg, value in self._connection_args.items():
            if arg == "headers":
                given_value = kwargs.setdefault(arg, {})
                if isinstance(given_value, dict):
                    given_value.update(value)
                    # Ensure the auth header is current
                    given_value[self.auth_header_key] = self._auth_header_value()
            else:
                kwargs[arg] = value

        if self.accept_cookies:
            self._update_cookie_header(kwargs)

        return kwargs

    def update_cookies(self, cookie: SimpleCookie) -> None:
        """Update cookies from existing requests for load balancers.

        Does nothing unless ``accept_cookies`` is enabled. Each morsel is stored together with
        the time it was stored; an ``expires`` attribute is converted into ``max-age`` (relative
        to that time) unless ``max-age`` is already present. An unparseable ``expires`` value is
        ignored, leaving the cookie without an expiry.
        """
        if not self.accept_cookies:
            return

        store_time = datetime.now(tz=timezone.utc)
        for key, item in cookie.items():
            # Convert "expires" arg into "max-age" to facilitate expiration management.
            # As "max-age" has precedence, ignore "expires" when "max-age" exists.
            if item.get("expires") and not item.get("max-age"):
                try:
                    expires = parsedate_to_datetime(item["expires"])
                except (TypeError, ValueError):
                    # Malformed date (TypeError on older Pythons, ValueError on 3.10+):
                    # ignore the attribute rather than failing the whole response.
                    expires = None
                if expires is not None:
                    if expires.tzinfo is None:
                        # parsedate_to_datetime returns a naive datetime for "-0000";
                        # treat it as UTC so it can be compared with the aware store_time.
                        expires = expires.replace(tzinfo=timezone.utc)
                    item["max-age"] = str((expires - store_time).total_seconds())

            self._cookies[key] = (item, store_time)

    def _clear_expired_cookies(self) -> None:
        """Clear expired cookies.

        A cookie is expired when more than ``max-age`` seconds have elapsed since it was stored.
        Cookies without a (numeric) ``max-age`` are kept.
        """
        check_time = datetime.now(tz=timezone.utc)
        expired_keys = []

        for key, (morsel, store_time) in self._cookies.items():
            cookie_max_age = morsel.get("max-age")
            if not cookie_max_age:
                continue
            try:
                max_age = float(cookie_max_age)
            except ValueError:
                # Non-numeric max-age: ignore the attribute instead of raising on every request.
                continue
            if (check_time - store_time).total_seconds() > max_age:
                expired_keys.append(key)

        for key in expired_keys:
            del self._cookies[key]

    def _update_cookie_header(self, connection_args: dict[str, ty.Any]) -> None:
        """Add the stored (unexpired) cookies to the Cookie header of ``connection_args``.

        Cookies already present in the request's Cookie header are kept, with the
        gateway-managed cookies appended after them.
        """
        self._clear_expired_cookies()

        gateway_cookie_values = "; ".join(
            f"{name}={morsel.coded_value}" for name, (morsel, _time) in self._cookies.items()
        )
        if gateway_cookie_values:
            headers = connection_args.get("headers", {})

            # As headers are case-insensitive, we get existing name of cookie header,
            # or use "Cookie" by default.
            cookie_header_name = next(
                (header_key for header_key in headers if header_key.lower() == "cookie"),
                "Cookie",
            )
            existing_cookie = headers.get(cookie_header_name)

            # merge gateway-managed cookies with cookies already in arguments
            if existing_cookie:
                gateway_cookie_values = existing_cookie + "; " + gateway_cookie_values
            headers[cookie_header_name] = gateway_cookie_values

            connection_args["headers"] = headers
class RetryableHTTPClient:
    """
    Inspired by urllib.util.Retry (https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html),
    this class is initialized with desired retry characteristics and wraps an instance of
    `AsyncHTTPClient`. `fetch()` re-issues a failed request while the failure is deemed
    retryable, tracking the current retry count across the attempts of one request.
    """

    MAX_RETRIES_DEFAULT = 2
    MAX_RETRIES_CAP = 10  # The upper limit to max_retries value.
    max_retries: int = int(os.getenv("JUPYTER_GATEWAY_MAX_REQUEST_RETRIES", MAX_RETRIES_DEFAULT))
    max_retries = max(0, min(max_retries, MAX_RETRIES_CAP))  # Enforce boundaries
    retried_methods: set[str] = {"GET", "DELETE"}
    retried_errors: set[int] = {502, 503, 504, 599}
    retried_exceptions: set[type] = {ConnectionError}
    backoff_factor: float = 0.1

    def __init__(self):
        """Initialize the retryable http client."""
        self.retry_count: int = 0
        self.client: AsyncHTTPClient = AsyncHTTPClient()

    async def fetch(self, endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
        """
        Retryable AsyncHTTPClient.fetch() method.  When the request fails, it is retried
        up to max_retries times if the condition deserves a retry.
        """
        self.retry_count = 0
        return await self._fetch(endpoint, **kwargs)

    async def _fetch(self, endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
        """
        Performs the fetch against the contained AsyncHTTPClient instance, retrying (after
        the backoff applied by `_is_retryable`) while failures are retryable. The last
        non-retryable exception is re-raised unchanged.
        """
        # Resolve the method once, outside the handler: tornado defaults to GET when no
        # method is given, and a KeyError here must not mask the real request error.
        method = str(kwargs.get("method", "GET")).upper()
        while True:
            try:
                return await self.client.fetch(endpoint, **kwargs)
            except Exception as e:
                if not await self._is_retryable(method, e):
                    raise
                logging.getLogger("ServerApp").info(
                    f"Attempting retry ({self.retry_count}) against "
                    f"endpoint '{endpoint}'.  Retried error: '{e!r}'"
                )

    async def _is_retryable(self, method: str, exception: Exception) -> bool:
        """Determines if the given exception is retryable based on object's configuration.

        When it is, sleeps for the exponential backoff (``backoff_factor * 2**retry_count``)
        and increments ``retry_count`` before returning True.
        """
        if method not in self.retried_methods:
            return False
        if self.retry_count >= self.max_retries:
            return False

        # Determine if error is retryable...
        if isinstance(exception, HTTPClientError):
            if exception.code not in self.retried_errors:
                return False
        elif not any(isinstance(exception, error) for error in self.retried_exceptions):
            return False

        # Is retryable, wait for backoff, then increment count
        await asyncio.sleep(self.backoff_factor * (2**self.retry_count))
        self.retry_count += 1
        return True
def _gateway_error_details(error: HTTPClientError, default_reason: str) -> tuple[str, str]:
    """Return ``(message, reason)`` describing a Gateway HTTP error.

    Prefers the ``message``/``reason`` fields of a JSON-object response body. A non-JSON (or
    non-UTF-8) body is used verbatim, with undecodable bytes replaced, as the reason.
    Otherwise ``error.message`` and ``default_reason`` are used.
    """
    error_message = error.message
    error_reason = default_reason
    body = error.response.body if error.response is not None else b""
    if not body:
        # Nothing to parse; keep the defaults instead of reporting an empty reason.
        return error_message, error_reason
    try:
        error_payload = json.loads(body)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return error_message, body.decode(errors="replace")
    if isinstance(error_payload, dict):
        error_reason = error_payload.get("reason") or error_reason
        error_message = error_payload.get("message") or error_message
    return error_message, error_reason


async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
    """Make an async request to kernel gateway endpoint, returns a response.

    Connection arguments are merged in by ``GatewayClient.load_connection_args`` and the
    request goes through ``RetryableHTTPClient``. An event is emitted for every outcome.
    HTTP, connection and DNS failures are converted to ``tornado.web.HTTPError``; any other
    exception is logged and re-raised. When ``accept_cookies`` is enabled, every
    ``Set-Cookie`` header of a successful response is stored on the GatewayClient.
    """
    gateway_client = GatewayClient.instance()
    kwargs = gateway_client.load_connection_args(**kwargs)
    rhc = RetryableHTTPClient()
    try:
        response = await rhc.fetch(endpoint, **kwargs)
    # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect
    # or the server is not running.
    # NOTE: We do this here since this handler is called during the server's startup and subsequent refreshes
    # of the tree view.
    except HTTPClientError as e:
        gateway_client.emit(
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: e.code, MESSAGE_KEY: str(e.message)}
        )
        error_message, error_reason = _gateway_error_details(
            e,
            f"Exception while attempting to connect to Gateway server url '{gateway_client.url}'",
        )
        raise web.HTTPError(
            e.code,
            f"Error from Gateway: [{error_message}] {error_reason}. "
            "Ensure gateway url is valid and the Gateway instance is running.",
        ) from e
    except ConnectionError as e:
        gateway_client.emit(
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 503, MESSAGE_KEY: str(e)}
        )
        raise web.HTTPError(
            503,
            f"ConnectionError was received from Gateway server url '{gateway_client.url}'. "
            "Check to be sure the Gateway instance is running.",
        ) from e
    except gaierror as e:
        gateway_client.emit(
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 404, MESSAGE_KEY: str(e)}
        )
        raise web.HTTPError(
            404,
            f"The Gateway server specified in the gateway_url '{gateway_client.url}' doesn't "
            f"appear to be valid.  Ensure gateway url is valid and the Gateway instance is running.",
        ) from e
    except Exception as e:
        gateway_client.emit(
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 505, MESSAGE_KEY: str(e)}
        )
        logging.getLogger("ServerApp").error(
            "Exception while trying to launch kernel via Gateway URL %s: %s",
            gateway_client.url,
            e,
        )
        raise
    else:
        # Report the status actually returned (e.g. 201/204), not a hard-coded 200.
        gateway_client.emit(
            data={STATUS_KEY: SUCCESS_STATUS, STATUS_CODE_KEY: response.code, MESSAGE_KEY: "success"}
        )

    if gateway_client.accept_cookies:
        # Update cookies on GatewayClient from server if configured.
        # HTTPHeaders.get() comma-joins repeated headers, which corrupts cookie parsing
        # (Expires values themselves contain commas), so read each Set-Cookie separately.
        response_headers = response.headers
        get_list = getattr(response_headers, "get_list", None)
        if callable(get_list):
            cookie_values = get_list("Set-Cookie")
        else:
            single_value = response_headers.get("Set-Cookie")
            cookie_values = [single_value] if single_value else []
        if cookie_values:
            cookie: SimpleCookie = SimpleCookie()
            for cookie_value in cookie_values:
                cookie.load(cookie_value)
            gateway_client.update_cookies(cookie)
    return response
|