websocket.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. """Base websocket classes."""
  2. import re
  3. import warnings
  4. from typing import Optional, no_type_check
  5. from urllib.parse import urlparse
  6. from tornado import ioloop, web
  7. from tornado.iostream import IOStream
  8. from jupyter_server.base.handlers import JupyterHandler
  9. from jupyter_server.utils import JupyterServerAuthWarning
  10. # ping interval for keeping websockets alive (30 seconds)
  11. WS_PING_INTERVAL = 30000
  12. class WebSocketMixin:
  13. """Mixin for common websocket options"""
  14. ping_callback = None
  15. last_ping = 0.0
  16. last_pong = 0.0
  17. stream: Optional[IOStream] = None
  18. @property
  19. def ping_interval(self):
  20. """The interval for websocket keep-alive pings.
  21. Set ws_ping_interval = 0 to disable pings.
  22. """
  23. return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined]
  24. @property
  25. def ping_timeout(self):
  26. """If no ping is received in this many milliseconds,
  27. close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
  28. Default is max of 3 pings or 30 seconds.
  29. """
  30. return self.settings.get( # type:ignore[attr-defined]
  31. "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
  32. )
  33. @no_type_check
  34. def check_origin(self, origin: Optional[str] = None) -> bool:
  35. """Check Origin == Host or Access-Control-Allow-Origin.
  36. Tornado >= 4 calls this method automatically, raising 403 if it returns False.
  37. """
  38. if self.allow_origin == "*" or (
  39. hasattr(self, "skip_check_origin") and self.skip_check_origin()
  40. ):
  41. return True
  42. host = self.request.headers.get("Host")
  43. if origin is None:
  44. origin = self.get_origin()
  45. # If no origin or host header is provided, assume from script
  46. if origin is None or host is None:
  47. return True
  48. origin = origin.lower()
  49. origin_host = urlparse(origin).netloc
  50. # OK if origin matches host
  51. if origin_host == host:
  52. return True
  53. # Check CORS headers
  54. if self.allow_origin:
  55. allow = self.allow_origin == origin
  56. elif self.allow_origin_pat:
  57. allow = bool(re.match(self.allow_origin_pat, origin))
  58. else:
  59. # No CORS headers deny the request
  60. allow = False
  61. if not allow:
  62. self.log.warning(
  63. "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
  64. origin,
  65. host,
  66. )
  67. return allow
  68. def clear_cookie(self, *args, **kwargs):
  69. """meaningless for websockets"""
  70. @no_type_check
  71. def _maybe_auth(self):
  72. """Verify authentication if required.
  73. Only used when the websocket class does not inherit from JupyterHandler.
  74. """
  75. if not self.settings.get("allow_unauthenticated_access", False):
  76. if not self.request.method:
  77. raise web.HTTPError(403)
  78. method = getattr(self, self.request.method.lower())
  79. if not getattr(method, "__allow_unauthenticated", False):
  80. # rather than re-using `web.authenticated` which also redirects
  81. # to login page on GET, just raise 403 if user is not known
  82. user = self.current_user
  83. if user is None:
  84. self.log.warning("Couldn't authenticate WebSocket connection")
  85. raise web.HTTPError(403)
  86. @no_type_check
  87. def prepare(self, *args, **kwargs):
  88. """Handle a get request."""
  89. if not isinstance(self, JupyterHandler):
  90. should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
  91. if "identity_provider" in self.settings and should_authenticate:
  92. warnings.warn(
  93. "WebSocketMixin sub-class does not inherit from JupyterHandler"
  94. " preventing proper authentication using custom identity provider.",
  95. JupyterServerAuthWarning,
  96. stacklevel=2,
  97. )
  98. self._maybe_auth()
  99. return super().prepare(*args, **kwargs)
  100. return super().prepare(*args, **kwargs, _redirect_to_login=False)
  101. @no_type_check
  102. def open(self, *args, **kwargs):
  103. """Open the websocket."""
  104. self.log.debug("Opening websocket %s", self.request.path)
  105. # start the pinging
  106. if self.ping_interval > 0:
  107. loop = ioloop.IOLoop.current()
  108. self.last_ping = loop.time() # Remember time of last ping
  109. self.last_pong = self.last_ping
  110. self.ping_callback = ioloop.PeriodicCallback(
  111. self.send_ping,
  112. self.ping_interval,
  113. )
  114. self.ping_callback.start()
  115. return super().open(*args, **kwargs)
  116. @no_type_check
  117. def send_ping(self):
  118. """send a ping to keep the websocket alive"""
  119. if self.ws_connection is None and self.ping_callback is not None:
  120. self.ping_callback.stop()
  121. return
  122. if self.ws_connection.client_terminated:
  123. self.close()
  124. return
  125. # check for timeout on pong. Make sure that we really have sent a recent ping in
  126. # case the machine with both server and client has been suspended since the last ping.
  127. now = ioloop.IOLoop.current().time()
  128. since_last_pong = 1e3 * (now - self.last_pong)
  129. since_last_ping = 1e3 * (now - self.last_ping)
  130. if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
  131. self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
  132. self.close()
  133. return
  134. self.ping(b"")
  135. self.last_ping = now
  136. def on_pong(self, data):
  137. """Handle a pong message."""
  138. self.last_pong = ioloop.IOLoop.current().time()