| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- """Base websocket classes."""
- import re
- import warnings
- from typing import Optional, no_type_check
- from urllib.parse import urlparse
- from tornado import ioloop, web
- from tornado.iostream import IOStream
- from jupyter_server.base.handlers import JupyterHandler
- from jupyter_server.utils import JupyterServerAuthWarning
# Default interval between websocket keep-alive pings, in milliseconds
# (30_000 ms == 30 seconds).
WS_PING_INTERVAL = 30_000
class WebSocketMixin:
    """Mixin for common websocket options.

    Adds keep-alive pinging with a pong timeout, Origin checking, and (for
    handlers that do not inherit from ``JupyterHandler``) a minimal
    authentication check. The concrete handler must provide ``settings``,
    ``request``, ``log``, ``ws_connection``, ``ping``, ``close``,
    ``current_user``, ``allow_origin`` and ``allow_origin_pat`` -- presumably
    a ``tornado.websocket.WebSocketHandler`` subclass (confirm at call sites).
    """

    # Periodic keep-alive callback; created in ``open`` when pinging is enabled.
    ping_callback = None
    # Event-loop timestamps (seconds, from ``IOLoop.time()``) of the last ping
    # sent and the last pong received.
    last_ping = 0.0
    last_pong = 0.0
    stream: Optional[IOStream] = None

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Read from the ``ws_ping_interval`` setting, defaulting to
        ``WS_PING_INTERVAL``. Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)  # type:ignore[attr-defined]

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds, close the connection.

        VPNs, etc. can fail to cleanly close websocket connections, so a stale
        connection is detected by the absence of pongs. Read from the
        ``ws_ping_timeout`` setting; the default is the larger of three ping
        intervals and ``WS_PING_INTERVAL`` (30 seconds).
        """
        return self.settings.get(  # type:ignore[attr-defined]
            "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    @no_type_check
    def check_origin(self, origin: Optional[str] = None) -> bool:
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        Parameters
        ----------
        origin : str, optional
            The Origin to validate; when omitted it is obtained from
            ``self.get_origin()``.

        Returns
        -------
        bool
            True when ``allow_origin`` is ``"*"``, ``skip_check_origin()``
            says so, the Origin or Host header is missing, the Origin's host
            matches the Host header, or the Origin is allowed by
            ``allow_origin`` / ``allow_origin_pat``; False otherwise (and a
            warning is logged).
        """
        if self.allow_origin == "*" or (
            hasattr(self, "skip_check_origin") and self.skip_check_origin()
        ):
            return True

        host = self.request.headers.get("Host")
        if origin is None:
            origin = self.get_origin()

        # If no origin or host header is provided, assume from script
        if origin is None or host is None:
            return True

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host
        if origin_host == host:
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            # re.match anchors only at the start of ``origin``.
            allow = bool(re.match(self.allow_origin_pat, origin))
        else:
            # No CORS headers deny the request
            allow = False

        if not allow:
            self.log.warning(
                "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin,
                host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """Do nothing: clearing cookies is meaningless for websockets."""

    @no_type_check
    def _maybe_auth(self):
        """Verify authentication if required; raise ``HTTPError(403)`` otherwise.

        Only used when the websocket class does not inherit from JupyterHandler.
        The check is skipped when the ``allow_unauthenticated_access`` setting
        is true, or when the handler method for the request verb carries a
        truthy ``__allow_unauthenticated`` attribute.
        """
        if not self.settings.get("allow_unauthenticated_access", False):
            if not self.request.method:
                raise web.HTTPError(403)
            method = getattr(self, self.request.method.lower())
            if not getattr(method, "__allow_unauthenticated", False):
                # rather than re-using `web.authenticated` which also redirects
                # to login page on GET, just raise 403 if user is not known
                user = self.current_user
                if user is None:
                    self.log.warning("Couldn't authenticate WebSocket connection")
                    raise web.HTTPError(403)

    @no_type_check
    def prepare(self, *args, **kwargs):
        """Run authentication checks before the request is processed.

        For handlers that are not ``JupyterHandler`` subclasses, emit a
        ``JupyterServerAuthWarning`` when a custom ``identity_provider`` is
        configured (it cannot be honoured here) and fall back to
        ``_maybe_auth``. ``JupyterHandler`` subclasses delegate to their own
        ``prepare`` with ``_redirect_to_login=False`` -- presumably so an
        unauthenticated websocket request is rejected rather than redirected.
        """
        if not isinstance(self, JupyterHandler):
            should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
            if "identity_provider" in self.settings and should_authenticate:
                warnings.warn(
                    "WebSocketMixin sub-class does not inherit from JupyterHandler"
                    " preventing proper authentication using custom identity provider.",
                    JupyterServerAuthWarning,
                    stacklevel=2,
                )
            self._maybe_auth()
            return super().prepare(*args, **kwargs)
        return super().prepare(*args, **kwargs, _redirect_to_login=False)

    @no_type_check
    def open(self, *args, **kwargs):
        """Open the websocket and start keep-alive pinging when enabled.

        Pinging is enabled when ``ping_interval`` is positive; both timestamps
        are seeded with the current loop time so the first timeout check does
        not fire spuriously.
        """
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping,
                self.ping_interval,
            )
            self.ping_callback.start()
        return super().open(*args, **kwargs)

    @no_type_check
    def send_ping(self):
        """Send a ping to keep the websocket alive.

        Invoked periodically by ``ping_callback``. Stops pinging once the
        connection is gone, closes the socket if the client has terminated or
        no pong arrived within ``ping_timeout``, and otherwise sends an empty
        ping frame and records when it was sent.
        """
        if self.ws_connection is None:
            # The connection is gone: stop pinging. Return even when there is
            # no callback to stop, since ``ws_connection`` is dereferenced below.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        if self.ws_connection.client_terminated:
            self.close()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        # Loop time is in seconds; intervals/timeouts are in milliseconds.
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b"")
        self.last_ping = now

    def on_pong(self, data):
        """Record the arrival time of a pong (the payload ``data`` is ignored)."""
        self.last_pong = ioloop.IOLoop.current().time()
|