websocket.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725
  1. """Implementation of the WebSocket protocol.
  2. `WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional
  3. communication between the browser and server. WebSockets are supported in the
  4. current versions of all major browsers.
  5. This module implements the final version of the WebSocket protocol as
  6. defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_.
  7. .. versionchanged:: 4.0
  8. Removed support for the draft 76 protocol version.
  9. """
  10. import abc
  11. import asyncio
  12. import base64
  13. import functools
  14. import hashlib
  15. import logging
  16. import os
  17. import sys
  18. import struct
  19. import tornado
  20. from urllib.parse import urlparse
  21. import warnings
  22. import zlib
  23. from tornado.concurrent import Future, future_set_result_unless_cancelled
  24. from tornado.escape import utf8, native_str, to_unicode
  25. from tornado import gen, httpclient, httputil
  26. from tornado.ioloop import IOLoop
  27. from tornado.iostream import StreamClosedError, IOStream
  28. from tornado.log import gen_log, app_log
  29. from tornado.netutil import Resolver
  30. from tornado import simple_httpclient
  31. from tornado.queues import Queue
  32. from tornado.tcpclient import TCPClient
  33. from tornado.util import _websocket_mask
  34. from typing import (
  35. TYPE_CHECKING,
  36. cast,
  37. Any,
  38. Optional,
  39. Dict,
  40. Union,
  41. List,
  42. Awaitable,
  43. Callable,
  44. Tuple,
  45. Type,
  46. )
  47. from types import TracebackType
  48. if TYPE_CHECKING:
  49. from typing_extensions import Protocol
  50. # The zlib compressor types aren't actually exposed anywhere
  51. # publicly, so declare protocols for the portions we use.
  52. class _Compressor(Protocol):
  53. def compress(self, data: bytes) -> bytes:
  54. pass
  55. def flush(self, mode: int) -> bytes:
  56. pass
  57. class _Decompressor(Protocol):
  58. unconsumed_tail = b"" # type: bytes
  59. def decompress(self, data: bytes, max_length: int) -> bytes:
  60. pass
  61. class _WebSocketDelegate(Protocol):
  62. # The common base interface implemented by WebSocketHandler on
  63. # the server side and WebSocketClientConnection on the client
  64. # side.
  65. def on_ws_connection_close(
  66. self, close_code: Optional[int] = None, close_reason: Optional[str] = None
  67. ) -> None:
  68. pass
  69. def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
  70. pass
  71. def on_ping(self, data: bytes) -> None:
  72. pass
  73. def on_pong(self, data: bytes) -> None:
  74. pass
  75. def log_exception(
  76. self,
  77. typ: Optional[Type[BaseException]],
  78. value: Optional[BaseException],
  79. tb: Optional[TracebackType],
  80. ) -> None:
  81. pass
  82. _default_max_message_size = 10 * 1024 * 1024
  83. # log to "gen_log" but suppress duplicate log messages
  84. de_dupe_gen_log = functools.lru_cache(gen_log.log)
  85. class WebSocketError(Exception):
  86. pass
  87. class WebSocketClosedError(WebSocketError):
  88. """Raised by operations on a closed connection.
  89. .. versionadded:: 3.2
  90. """
  91. pass
  92. class _DecompressTooLargeError(Exception):
  93. pass
  94. class _WebSocketParams:
  95. def __init__(
  96. self,
  97. ping_interval: Optional[float] = None,
  98. ping_timeout: Optional[float] = None,
  99. max_message_size: int = _default_max_message_size,
  100. compression_options: Optional[Dict[str, Any]] = None,
  101. ) -> None:
  102. self.ping_interval = ping_interval
  103. self.ping_timeout = ping_timeout
  104. self.max_message_size = max_message_size
  105. self.compression_options = compression_options
  106. class WebSocketHandler(tornado.web.RequestHandler):
  107. """Subclass this class to create a basic WebSocket handler.
  108. Override `on_message` to handle incoming messages, and use
  109. `write_message` to send messages to the client. You can also
  110. override `open` and `on_close` to handle opened and closed
  111. connections.
  112. Custom upgrade response headers can be sent by overriding
  113. `~tornado.web.RequestHandler.set_default_headers` or
  114. `~tornado.web.RequestHandler.prepare`.
  115. See http://dev.w3.org/html5/websockets/ for details on the
  116. JavaScript interface. The protocol is specified at
  117. http://tools.ietf.org/html/rfc6455.
  118. Here is an example WebSocket handler that echos back all received messages
  119. back to the client:
  120. .. testcode::
  121. class EchoWebSocket(tornado.websocket.WebSocketHandler):
  122. def open(self):
  123. print("WebSocket opened")
  124. def on_message(self, message):
  125. self.write_message(u"You said: " + message)
  126. def on_close(self):
  127. print("WebSocket closed")
  128. WebSockets are not standard HTTP connections. The "handshake" is
  129. HTTP, but after the handshake, the protocol is
  130. message-based. Consequently, most of the Tornado HTTP facilities
  131. are not available in handlers of this type. The only communication
  132. methods available to you are `write_message()`, `ping()`, and
  133. `close()`. Likewise, your request handler class should implement
  134. `open()` method rather than ``get()`` or ``post()``.
  135. If you map the handler above to ``/websocket`` in your application, you can
  136. invoke it in JavaScript with::
  137. var ws = new WebSocket("ws://localhost:8888/websocket");
  138. ws.onopen = function() {
  139. ws.send("Hello, world");
  140. };
  141. ws.onmessage = function (evt) {
  142. alert(evt.data);
  143. };
  144. This script pops up an alert box that says "You said: Hello, world".
  145. Web browsers allow any site to open a websocket connection to any other,
  146. instead of using the same-origin policy that governs other network
  147. access from JavaScript. This can be surprising and is a potential
  148. security hole, so since Tornado 4.0 `WebSocketHandler` requires
  149. applications that wish to receive cross-origin websockets to opt in
  150. by overriding the `~WebSocketHandler.check_origin` method (see that
  151. method's docs for details). Failure to do so is the most likely
  152. cause of 403 errors when making a websocket connection.
  153. When using a secure websocket connection (``wss://``) with a self-signed
  154. certificate, the connection from a browser may fail because it wants
  155. to show the "accept this certificate" dialog but has nowhere to show it.
  156. You must first visit a regular HTML page using the same certificate
  157. to accept it before the websocket connection will succeed.
  158. If the application setting ``websocket_ping_interval`` has a non-zero
  159. value, a ping will be sent periodically, and the connection will be
  160. closed if a response is not received before the ``websocket_ping_timeout``.
  161. Both settings are in seconds; floating point values are allowed.
  162. The default timeout is equal to the interval.
  163. Messages larger than the ``websocket_max_message_size`` application setting
  164. (default 10MiB) will not be accepted.
  165. .. versionchanged:: 4.5
  166. Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
  167. ``websocket_max_message_size``.
  168. """
  169. def __init__(
  170. self,
  171. application: tornado.web.Application,
  172. request: httputil.HTTPServerRequest,
  173. **kwargs: Any,
  174. ) -> None:
  175. super().__init__(application, request, **kwargs)
  176. self.ws_connection = None # type: Optional[WebSocketProtocol]
  177. self.close_code = None # type: Optional[int]
  178. self.close_reason = None # type: Optional[str]
  179. self._on_close_called = False
  180. async def get(self, *args: Any, **kwargs: Any) -> None:
  181. self.open_args = args
  182. self.open_kwargs = kwargs
  183. # Upgrade header should be present and should be equal to WebSocket
  184. if self.request.headers.get("Upgrade", "").lower() != "websocket":
  185. self.set_status(400)
  186. log_msg = 'Can "Upgrade" only to "WebSocket".'
  187. self.finish(log_msg)
  188. gen_log.debug(log_msg)
  189. return
  190. # Connection header should be upgrade.
  191. # Some proxy servers/load balancers
  192. # might mess with it.
  193. headers = self.request.headers
  194. connection = map(
  195. lambda s: s.strip().lower(), headers.get("Connection", "").split(",")
  196. )
  197. if "upgrade" not in connection:
  198. self.set_status(400)
  199. log_msg = '"Connection" must be "Upgrade".'
  200. self.finish(log_msg)
  201. gen_log.debug(log_msg)
  202. return
  203. # Handle WebSocket Origin naming convention differences
  204. # The difference between version 8 and 13 is that in 8 the
  205. # client sends a "Sec-Websocket-Origin" header and in 13 it's
  206. # simply "Origin".
  207. if "Origin" in self.request.headers:
  208. origin = self.request.headers.get("Origin")
  209. else:
  210. origin = self.request.headers.get("Sec-Websocket-Origin", None)
  211. # If there was an origin header, check to make sure it matches
  212. # according to check_origin. When the origin is None, we assume it
  213. # did not come from a browser and that it can be passed on.
  214. if origin is not None and not self.check_origin(origin):
  215. self.set_status(403)
  216. log_msg = "Cross origin websockets not allowed"
  217. self.finish(log_msg)
  218. gen_log.debug(log_msg)
  219. return
  220. self.ws_connection = self.get_websocket_protocol()
  221. if self.ws_connection:
  222. await self.ws_connection.accept_connection(self)
  223. else:
  224. self.set_status(426, "Upgrade Required")
  225. self.set_header("Sec-WebSocket-Version", "7, 8, 13")
  226. @property
  227. def ping_interval(self) -> Optional[float]:
  228. """The interval for sending websocket pings.
  229. If this is non-zero, the websocket will send a ping every
  230. ping_interval seconds.
  231. The client will respond with a "pong". The connection can be configured
  232. to timeout on late pong delivery using ``websocket_ping_timeout``.
  233. Set ``websocket_ping_interval = 0`` to disable pings.
  234. Default: ``0``
  235. """
  236. return self.settings.get("websocket_ping_interval", None)
  237. @property
  238. def ping_timeout(self) -> Optional[float]:
  239. """Timeout if no pong is received in this many seconds.
  240. To be used in combination with ``websocket_ping_interval > 0``.
  241. If a ping response (a "pong") is not received within
  242. ``websocket_ping_timeout`` seconds, then the websocket connection
  243. will be closed.
  244. This can help to clean up clients which have disconnected without
  245. cleanly closing the websocket connection.
  246. Note, the ping timeout cannot be longer than the ping interval.
  247. Set ``websocket_ping_timeout = 0`` to disable the ping timeout.
  248. Default: equal to the ``ping_interval``.
  249. .. versionchanged:: 6.5.0
  250. Default changed from the max of 3 pings or 30 seconds.
  251. The ping timeout can no longer be configured longer than the
  252. ping interval.
  253. """
  254. return self.settings.get("websocket_ping_timeout", None)
  255. @property
  256. def max_message_size(self) -> int:
  257. """Maximum allowed message size.
  258. If the remote peer sends a message larger than this, the connection
  259. will be closed.
  260. Default is 10MiB.
  261. """
  262. return self.settings.get(
  263. "websocket_max_message_size", _default_max_message_size
  264. )
  265. def write_message(
  266. self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False
  267. ) -> "Future[None]":
  268. """Sends the given message to the client of this Web Socket.
  269. The message may be either a string or a dict (which will be
  270. encoded as json). If the ``binary`` argument is false, the
  271. message will be sent as utf8; in binary mode any byte string
  272. is allowed.
  273. If the connection is already closed, raises `WebSocketClosedError`.
  274. Returns a `.Future` which can be used for flow control.
  275. .. versionchanged:: 3.2
  276. `WebSocketClosedError` was added (previously a closed connection
  277. would raise an `AttributeError`)
  278. .. versionchanged:: 4.3
  279. Returns a `.Future` which can be used for flow control.
  280. .. versionchanged:: 5.0
  281. Consistently raises `WebSocketClosedError`. Previously could
  282. sometimes raise `.StreamClosedError`.
  283. """
  284. if self.ws_connection is None or self.ws_connection.is_closing():
  285. raise WebSocketClosedError()
  286. if isinstance(message, dict):
  287. message = tornado.escape.json_encode(message)
  288. return self.ws_connection.write_message(message, binary=binary)
  289. def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
  290. """Override to implement subprotocol negotiation.
  291. ``subprotocols`` is a list of strings identifying the
  292. subprotocols proposed by the client. This method may be
  293. overridden to return one of those strings to select it, or
  294. ``None`` to not select a subprotocol.
  295. Failure to select a subprotocol does not automatically abort
  296. the connection, although clients may close the connection if
  297. none of their proposed subprotocols was selected.
  298. The list may be empty, in which case this method must return
  299. None. This method is always called exactly once even if no
  300. subprotocols were proposed so that the handler can be advised
  301. of this fact.
  302. .. versionchanged:: 5.1
  303. Previously, this method was called with a list containing
  304. an empty string instead of an empty list if no subprotocols
  305. were proposed by the client.
  306. """
  307. return None
  308. @property
  309. def selected_subprotocol(self) -> Optional[str]:
  310. """The subprotocol returned by `select_subprotocol`.
  311. .. versionadded:: 5.1
  312. """
  313. assert self.ws_connection is not None
  314. return self.ws_connection.selected_subprotocol
  315. def get_compression_options(self) -> Optional[Dict[str, Any]]:
  316. """Override to return compression options for the connection.
  317. If this method returns None (the default), compression will
  318. be disabled. If it returns a dict (even an empty one), it
  319. will be enabled. The contents of the dict may be used to
  320. control the following compression options:
  321. ``compression_level`` specifies the compression level.
  322. ``mem_level`` specifies the amount of memory used for the internal compression state.
  323. These parameters are documented in detail here:
  324. https://docs.python.org/3.13/library/zlib.html#zlib.compressobj
  325. .. versionadded:: 4.1
  326. .. versionchanged:: 4.5
  327. Added ``compression_level`` and ``mem_level``.
  328. """
  329. # TODO: Add wbits option.
  330. return None
  331. def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
  332. """Invoked when a new WebSocket is opened.
  333. The arguments to `open` are extracted from the `tornado.web.URLSpec`
  334. regular expression, just like the arguments to
  335. `tornado.web.RequestHandler.get`.
  336. `open` may be a coroutine. `on_message` will not be called until
  337. `open` has returned.
  338. .. versionchanged:: 5.1
  339. ``open`` may be a coroutine.
  340. """
  341. pass
  342. def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
  343. """Handle incoming messages on the WebSocket
  344. This method must be overridden.
  345. .. versionchanged:: 4.5
  346. ``on_message`` can be a coroutine.
  347. """
  348. raise NotImplementedError
  349. def ping(self, data: Union[str, bytes] = b"") -> None:
  350. """Send ping frame to the remote end.
  351. The data argument allows a small amount of data (up to 125
  352. bytes) to be sent as a part of the ping message. Note that not
  353. all websocket implementations expose this data to
  354. applications.
  355. Consider using the ``websocket_ping_interval`` application
  356. setting instead of sending pings manually.
  357. .. versionchanged:: 5.1
  358. The data argument is now optional.
  359. """
  360. data = utf8(data)
  361. if self.ws_connection is None or self.ws_connection.is_closing():
  362. raise WebSocketClosedError()
  363. self.ws_connection.write_ping(data)
  364. def on_pong(self, data: bytes) -> None:
  365. """Invoked when the response to a ping frame is received."""
  366. pass
  367. def on_ping(self, data: bytes) -> None:
  368. """Invoked when the a ping frame is received."""
  369. pass
  370. def on_close(self) -> None:
  371. """Invoked when the WebSocket is closed.
  372. If the connection was closed cleanly and a status code or reason
  373. phrase was supplied, these values will be available as the attributes
  374. ``self.close_code`` and ``self.close_reason``.
  375. .. versionchanged:: 4.0
  376. Added ``close_code`` and ``close_reason`` attributes.
  377. """
  378. pass
  379. def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
  380. """Closes this Web Socket.
  381. Once the close handshake is successful the socket will be closed.
  382. ``code`` may be a numeric status code, taken from the values
  383. defined in `RFC 6455 section 7.4.1
  384. <https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
  385. ``reason`` may be a textual message about why the connection is
  386. closing. These values are made available to the client, but are
  387. not otherwise interpreted by the websocket protocol.
  388. .. versionchanged:: 4.0
  389. Added the ``code`` and ``reason`` arguments.
  390. """
  391. if self.ws_connection:
  392. self.ws_connection.close(code, reason)
  393. self.ws_connection = None
  394. def check_origin(self, origin: str) -> bool:
  395. """Override to enable support for allowing alternate origins.
  396. The ``origin`` argument is the value of the ``Origin`` HTTP
  397. header, the url responsible for initiating this request. This
  398. method is not called for clients that do not send this header;
  399. such requests are always allowed (because all browsers that
  400. implement WebSockets support this header, and non-browser
  401. clients do not have the same cross-site security concerns).
  402. Should return ``True`` to accept the request or ``False`` to
  403. reject it. By default, rejects all requests with an origin on
  404. a host other than this one.
  405. This is a security protection against cross site scripting attacks on
  406. browsers, since WebSockets are allowed to bypass the usual same-origin
  407. policies and don't use CORS headers.
  408. .. warning::
  409. This is an important security measure; don't disable it
  410. without understanding the security implications. In
  411. particular, if your authentication is cookie-based, you
  412. must either restrict the origins allowed by
  413. ``check_origin()`` or implement your own XSRF-like
  414. protection for websocket connections. See `these
  415. <https://www.christian-schneider.net/CrossSiteWebSocketHijacking.html>`_
  416. `articles
  417. <https://devcenter.heroku.com/articles/websocket-security>`_
  418. for more.
  419. To accept all cross-origin traffic (which was the default prior to
  420. Tornado 4.0), simply override this method to always return ``True``::
  421. def check_origin(self, origin):
  422. return True
  423. To allow connections from any subdomain of your site, you might
  424. do something like::
  425. def check_origin(self, origin):
  426. parsed_origin = urllib.parse.urlparse(origin)
  427. return parsed_origin.netloc.endswith(".mydomain.com")
  428. .. versionadded:: 4.0
  429. """
  430. parsed_origin = urlparse(origin)
  431. origin = parsed_origin.netloc
  432. origin = origin.lower()
  433. host = self.request.headers.get("Host")
  434. # Check to see that origin matches host directly, including ports
  435. return origin == host
  436. def set_nodelay(self, value: bool) -> None:
  437. """Set the no-delay flag for this stream.
  438. By default, small messages may be delayed and/or combined to minimize
  439. the number of packets sent. This can sometimes cause 200-500ms delays
  440. due to the interaction between Nagle's algorithm and TCP delayed
  441. ACKs. To reduce this delay (at the expense of possibly increasing
  442. bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
  443. connection is established.
  444. See `.BaseIOStream.set_nodelay` for additional details.
  445. .. versionadded:: 3.1
  446. """
  447. assert self.ws_connection is not None
  448. self.ws_connection.set_nodelay(value)
  449. def on_connection_close(self) -> None:
  450. if self.ws_connection:
  451. self.ws_connection.on_connection_close()
  452. self.ws_connection = None
  453. if not self._on_close_called:
  454. self._on_close_called = True
  455. self.on_close()
  456. self._break_cycles()
  457. def on_ws_connection_close(
  458. self, close_code: Optional[int] = None, close_reason: Optional[str] = None
  459. ) -> None:
  460. self.close_code = close_code
  461. self.close_reason = close_reason
  462. self.on_connection_close()
  463. def _break_cycles(self) -> None:
  464. # WebSocketHandlers call finish() early, but we don't want to
  465. # break up reference cycles (which makes it impossible to call
  466. # self.render_string) until after we've really closed the
  467. # connection (if it was established in the first place,
  468. # indicated by status code 101).
  469. if self.get_status() != 101 or self._on_close_called:
  470. super()._break_cycles()
  471. def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
  472. websocket_version = self.request.headers.get("Sec-WebSocket-Version")
  473. if websocket_version in ("7", "8", "13"):
  474. params = _WebSocketParams(
  475. ping_interval=self.ping_interval,
  476. ping_timeout=self.ping_timeout,
  477. max_message_size=self.max_message_size,
  478. compression_options=self.get_compression_options(),
  479. )
  480. return WebSocketProtocol13(self, False, params)
  481. return None
  482. def _detach_stream(self) -> IOStream:
  483. # disable non-WS methods
  484. for method in [
  485. "write",
  486. "redirect",
  487. "set_header",
  488. "set_cookie",
  489. "set_status",
  490. "flush",
  491. "finish",
  492. ]:
  493. setattr(self, method, _raise_not_supported_for_websockets)
  494. return self.detach()
  495. def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
  496. raise RuntimeError("Method not supported for Web Sockets")
  497. class WebSocketProtocol(abc.ABC):
  498. """Base class for WebSocket protocol versions."""
  499. def __init__(self, handler: "_WebSocketDelegate") -> None:
  500. self.handler = handler
  501. self.stream = None # type: Optional[IOStream]
  502. self.client_terminated = False
  503. self.server_terminated = False
  504. def _run_callback(
  505. self, callback: Callable, *args: Any, **kwargs: Any
  506. ) -> "Optional[Future[Any]]":
  507. """Runs the given callback with exception handling.
  508. If the callback is a coroutine, returns its Future. On error, aborts the
  509. websocket connection and returns None.
  510. """
  511. try:
  512. result = callback(*args, **kwargs)
  513. except Exception:
  514. self.handler.log_exception(*sys.exc_info())
  515. self._abort()
  516. return None
  517. else:
  518. if result is not None:
  519. result = gen.convert_yielded(result)
  520. assert self.stream is not None
  521. self.stream.io_loop.add_future(result, lambda f: f.result())
  522. return result
  523. def on_connection_close(self) -> None:
  524. self._abort()
  525. def _abort(self) -> None:
  526. """Instantly aborts the WebSocket connection by closing the socket"""
  527. self.client_terminated = True
  528. self.server_terminated = True
  529. if self.stream is not None:
  530. self.stream.close() # forcibly tear down the connection
  531. self.close() # let the subclass cleanup
  532. @abc.abstractmethod
  533. def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
  534. raise NotImplementedError()
  535. @abc.abstractmethod
  536. def is_closing(self) -> bool:
  537. raise NotImplementedError()
  538. @abc.abstractmethod
  539. async def accept_connection(self, handler: WebSocketHandler) -> None:
  540. raise NotImplementedError()
  541. @abc.abstractmethod
  542. def write_message(
  543. self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
  544. ) -> "Future[None]":
  545. raise NotImplementedError()
  546. @property
  547. @abc.abstractmethod
  548. def selected_subprotocol(self) -> Optional[str]:
  549. raise NotImplementedError()
  550. @abc.abstractmethod
  551. def write_ping(self, data: bytes) -> None:
  552. raise NotImplementedError()
  553. # The entry points below are used by WebSocketClientConnection,
  554. # which was introduced after we only supported a single version of
  555. # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13
  556. # boundary is currently pretty ad-hoc.
  557. @abc.abstractmethod
  558. def _process_server_headers(
  559. self, key: Union[str, bytes], headers: httputil.HTTPHeaders
  560. ) -> None:
  561. raise NotImplementedError()
  562. @abc.abstractmethod
  563. def start_pinging(self) -> None:
  564. raise NotImplementedError()
  565. @abc.abstractmethod
  566. async def _receive_frame_loop(self) -> None:
  567. raise NotImplementedError()
  568. @abc.abstractmethod
  569. def set_nodelay(self, x: bool) -> None:
  570. raise NotImplementedError()
  571. class _PerMessageDeflateCompressor:
  572. def __init__(
  573. self,
  574. persistent: bool,
  575. max_wbits: Optional[int],
  576. compression_options: Optional[Dict[str, Any]] = None,
  577. ) -> None:
  578. if max_wbits is None:
  579. max_wbits = zlib.MAX_WBITS
  580. # There is no symbolic constant for the minimum wbits value.
  581. if not (8 <= max_wbits <= zlib.MAX_WBITS):
  582. raise ValueError(
  583. "Invalid max_wbits value %r; allowed range 8-%d",
  584. max_wbits,
  585. zlib.MAX_WBITS,
  586. )
  587. self._max_wbits = max_wbits
  588. if (
  589. compression_options is None
  590. or "compression_level" not in compression_options
  591. ):
  592. self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL
  593. else:
  594. self._compression_level = compression_options["compression_level"]
  595. if compression_options is None or "mem_level" not in compression_options:
  596. self._mem_level = 8
  597. else:
  598. self._mem_level = compression_options["mem_level"]
  599. if persistent:
  600. self._compressor = self._create_compressor() # type: Optional[_Compressor]
  601. else:
  602. self._compressor = None
  603. def _create_compressor(self) -> "_Compressor":
  604. return zlib.compressobj(
  605. self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level
  606. )
  607. def compress(self, data: bytes) -> bytes:
  608. compressor = self._compressor or self._create_compressor()
  609. data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH)
  610. assert data.endswith(b"\x00\x00\xff\xff")
  611. return data[:-4]
  612. class _PerMessageDeflateDecompressor:
  613. def __init__(
  614. self,
  615. persistent: bool,
  616. max_wbits: Optional[int],
  617. max_message_size: int,
  618. compression_options: Optional[Dict[str, Any]] = None,
  619. ) -> None:
  620. self._max_message_size = max_message_size
  621. if max_wbits is None:
  622. max_wbits = zlib.MAX_WBITS
  623. if not (8 <= max_wbits <= zlib.MAX_WBITS):
  624. raise ValueError(
  625. "Invalid max_wbits value %r; allowed range 8-%d",
  626. max_wbits,
  627. zlib.MAX_WBITS,
  628. )
  629. self._max_wbits = max_wbits
  630. if persistent:
  631. self._decompressor = (
  632. self._create_decompressor()
  633. ) # type: Optional[_Decompressor]
  634. else:
  635. self._decompressor = None
  636. def _create_decompressor(self) -> "_Decompressor":
  637. return zlib.decompressobj(-self._max_wbits)
  638. def decompress(self, data: bytes) -> bytes:
  639. decompressor = self._decompressor or self._create_decompressor()
  640. result = decompressor.decompress(
  641. data + b"\x00\x00\xff\xff", self._max_message_size
  642. )
  643. if decompressor.unconsumed_tail:
  644. raise _DecompressTooLargeError()
  645. return result
  646. class WebSocketProtocol13(WebSocketProtocol):
  647. """Implementation of the WebSocket protocol from RFC 6455.
  648. This class supports versions 7 and 8 of the protocol in addition to the
  649. final version 13.
  650. """
  651. # Bit masks for the first byte of a frame.
  652. FIN = 0x80
  653. RSV1 = 0x40
  654. RSV2 = 0x20
  655. RSV3 = 0x10
  656. RSV_MASK = RSV1 | RSV2 | RSV3
  657. OPCODE_MASK = 0x0F
  658. stream = None # type: IOStream
  659. def __init__(
  660. self,
  661. handler: "_WebSocketDelegate",
  662. mask_outgoing: bool,
  663. params: _WebSocketParams,
  664. ) -> None:
  665. WebSocketProtocol.__init__(self, handler)
  666. self.mask_outgoing = mask_outgoing
  667. self.params = params
  668. self._final_frame = False
  669. self._frame_opcode = None
  670. self._masked_frame = None
  671. self._frame_mask = None # type: Optional[bytes]
  672. self._frame_length = None
  673. self._fragmented_message_buffer = None # type: Optional[bytearray]
  674. self._fragmented_message_opcode = None
  675. self._waiting = None # type: object
  676. self._compression_options = params.compression_options
  677. self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor]
  678. self._compressor = None # type: Optional[_PerMessageDeflateCompressor]
  679. self._frame_compressed = None # type: Optional[bool]
  680. # The total uncompressed size of all messages received or sent.
  681. # Unicode messages are encoded to utf8.
  682. # Only for testing; subject to change.
  683. self._message_bytes_in = 0
  684. self._message_bytes_out = 0
  685. # The total size of all packets received or sent. Includes
  686. # the effect of compression, frame overhead, and control frames.
  687. self._wire_bytes_in = 0
  688. self._wire_bytes_out = 0
  689. self._received_pong = False # type: bool
  690. self.close_code = None # type: Optional[int]
  691. self.close_reason = None # type: Optional[str]
  692. self._ping_coroutine = None # type: Optional[asyncio.Task]
    # Use a property for this to satisfy the abc.
    @property
    def selected_subprotocol(self) -> Optional[str]:
        """The subprotocol stored by the handshake code.

        Reads ``self._selected_subprotocol``, which only exists once the
        setter has been called.
        """
        return self._selected_subprotocol

    @selected_subprotocol.setter
    def selected_subprotocol(self, value: Optional[str]) -> None:
        # Store the negotiated subprotocol (or None when none was chosen).
        self._selected_subprotocol = value
  700. async def accept_connection(self, handler: WebSocketHandler) -> None:
  701. try:
  702. self._handle_websocket_headers(handler)
  703. except ValueError:
  704. handler.set_status(400)
  705. log_msg = "Missing/Invalid WebSocket headers"
  706. handler.finish(log_msg)
  707. gen_log.debug(log_msg)
  708. return
  709. try:
  710. await self._accept_connection(handler)
  711. except asyncio.CancelledError:
  712. self._abort()
  713. return
  714. except ValueError:
  715. gen_log.debug("Malformed WebSocket request received", exc_info=True)
  716. self._abort()
  717. return
  718. def _handle_websocket_headers(self, handler: WebSocketHandler) -> None:
  719. """Verifies all invariant- and required headers
  720. If a header is missing or have an incorrect value ValueError will be
  721. raised
  722. """
  723. fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
  724. if not all(map(lambda f: handler.request.headers.get(f), fields)):
  725. raise ValueError("Missing/Invalid WebSocket headers")
  726. @staticmethod
  727. def compute_accept_value(key: Union[str, bytes]) -> str:
  728. """Computes the value for the Sec-WebSocket-Accept header,
  729. given the value for Sec-WebSocket-Key.
  730. """
  731. sha1 = hashlib.sha1()
  732. sha1.update(utf8(key))
  733. sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
  734. return native_str(base64.b64encode(sha1.digest()))
  735. def _challenge_response(self, handler: WebSocketHandler) -> str:
  736. return WebSocketProtocol13.compute_accept_value(
  737. cast(str, handler.request.headers.get("Sec-Websocket-Key"))
  738. )
    async def _accept_connection(self, handler: WebSocketHandler) -> None:
        """Complete the server handshake and run the connection.

        Negotiates the subprotocol and the permessage-deflate extension,
        sends the 101 response, takes over the stream, calls
        ``handler.open`` and then reads frames until the connection ends.
        """
        # Subprotocol negotiation: the handler picks from the offered list.
        subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
        if subprotocol_header:
            subprotocols = [s.strip() for s in subprotocol_header.split(",")]
        else:
            subprotocols = []
        self.selected_subprotocol = handler.select_subprotocol(subprotocols)
        if self.selected_subprotocol:
            assert self.selected_subprotocol in subprotocols
            handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)

        # Extension negotiation: only the first permessage-deflate offer is
        # used, and only when compression has been enabled.
        extensions = self._parse_extensions_header(handler.request.headers)
        for ext in extensions:
            if ext[0] == "permessage-deflate" and self._compression_options is not None:
                # TODO: negotiate parameters if compression_options
                # specifies limits.
                self._create_compressors("server", ext[1], self._compression_options)
                if (
                    "client_max_window_bits" in ext[1]
                    and ext[1]["client_max_window_bits"] is None
                ):
                    # Don't echo an offered client_max_window_bits
                    # parameter with no value.
                    del ext[1]["client_max_window_bits"]
                handler.set_header(
                    "Sec-WebSocket-Extensions",
                    httputil._encode_header("permessage-deflate", ext[1]),
                )
                break

        # Send the 101 response.
        handler.clear_header("Content-Type")
        handler.set_status(101)
        handler.set_header("Upgrade", "websocket")
        handler.set_header("Connection", "Upgrade")
        handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler))
        handler.finish()

        # The stream is detached only after the response has been finished.
        self.stream = handler._detach_stream()

        self.start_pinging()
        try:
            open_result = handler.open(*handler.open_args, **handler.open_kwargs)
            if open_result is not None:
                await open_result
        except Exception:
            # An error in the application's open() aborts the connection.
            handler.log_exception(*sys.exc_info())
            self._abort()
            return

        await self._receive_frame_loop()
  784. def _parse_extensions_header(
  785. self, headers: httputil.HTTPHeaders
  786. ) -> List[Tuple[str, Dict[str, str]]]:
  787. extensions = headers.get("Sec-WebSocket-Extensions", "")
  788. if extensions:
  789. return [httputil._parse_header(e.strip()) for e in extensions.split(",")]
  790. return []
  791. def _process_server_headers(
  792. self, key: Union[str, bytes], headers: httputil.HTTPHeaders
  793. ) -> None:
  794. """Process the headers sent by the server to this client connection.
  795. 'key' is the websocket handshake challenge/response key.
  796. """
  797. assert headers["Upgrade"].lower() == "websocket"
  798. assert headers["Connection"].lower() == "upgrade"
  799. accept = self.compute_accept_value(key)
  800. assert headers["Sec-Websocket-Accept"] == accept
  801. extensions = self._parse_extensions_header(headers)
  802. for ext in extensions:
  803. if ext[0] == "permessage-deflate" and self._compression_options is not None:
  804. self._create_compressors("client", ext[1])
  805. else:
  806. raise ValueError("unsupported extension %r", ext)
  807. self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None)
  808. def _get_compressor_options(
  809. self,
  810. side: str,
  811. agreed_parameters: Dict[str, Any],
  812. compression_options: Optional[Dict[str, Any]] = None,
  813. ) -> Dict[str, Any]:
  814. """Converts a websocket agreed_parameters set to keyword arguments
  815. for our compressor objects.
  816. """
  817. options = dict(
  818. persistent=(side + "_no_context_takeover") not in agreed_parameters
  819. ) # type: Dict[str, Any]
  820. wbits_header = agreed_parameters.get(side + "_max_window_bits", None)
  821. if wbits_header is None:
  822. options["max_wbits"] = zlib.MAX_WBITS
  823. else:
  824. options["max_wbits"] = int(wbits_header)
  825. options["compression_options"] = compression_options
  826. return options
  827. def _create_compressors(
  828. self,
  829. side: str,
  830. agreed_parameters: Dict[str, Any],
  831. compression_options: Optional[Dict[str, Any]] = None,
  832. ) -> None:
  833. # TODO: handle invalid parameters gracefully
  834. allowed_keys = {
  835. "server_no_context_takeover",
  836. "client_no_context_takeover",
  837. "server_max_window_bits",
  838. "client_max_window_bits",
  839. }
  840. for key in agreed_parameters:
  841. if key not in allowed_keys:
  842. raise ValueError("unsupported compression parameter %r" % key)
  843. other_side = "client" if (side == "server") else "server"
  844. self._compressor = _PerMessageDeflateCompressor(
  845. **self._get_compressor_options(side, agreed_parameters, compression_options)
  846. )
  847. self._decompressor = _PerMessageDeflateDecompressor(
  848. max_message_size=self.params.max_message_size,
  849. **self._get_compressor_options(
  850. other_side, agreed_parameters, compression_options
  851. ),
  852. )
  853. def _write_frame(
  854. self, fin: bool, opcode: int, data: bytes, flags: int = 0
  855. ) -> "Future[None]":
  856. data_len = len(data)
  857. if opcode & 0x8:
  858. # All control frames MUST have a payload length of 125
  859. # bytes or less and MUST NOT be fragmented.
  860. if not fin:
  861. raise ValueError("control frames may not be fragmented")
  862. if data_len > 125:
  863. raise ValueError("control frame payloads may not exceed 125 bytes")
  864. if fin:
  865. finbit = self.FIN
  866. else:
  867. finbit = 0
  868. frame = struct.pack("B", finbit | opcode | flags)
  869. if self.mask_outgoing:
  870. mask_bit = 0x80
  871. else:
  872. mask_bit = 0
  873. if data_len < 126:
  874. frame += struct.pack("B", data_len | mask_bit)
  875. elif data_len <= 0xFFFF:
  876. frame += struct.pack("!BH", 126 | mask_bit, data_len)
  877. else:
  878. frame += struct.pack("!BQ", 127 | mask_bit, data_len)
  879. if self.mask_outgoing:
  880. mask = os.urandom(4)
  881. data = mask + _websocket_mask(mask, data)
  882. frame += data
  883. self._wire_bytes_out += len(frame)
  884. return self.stream.write(frame)
  885. def write_message(
  886. self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
  887. ) -> "Future[None]":
  888. """Sends the given message to the client of this Web Socket."""
  889. if binary:
  890. opcode = 0x2
  891. else:
  892. opcode = 0x1
  893. if isinstance(message, dict):
  894. message = tornado.escape.json_encode(message)
  895. message = tornado.escape.utf8(message)
  896. assert isinstance(message, bytes)
  897. self._message_bytes_out += len(message)
  898. flags = 0
  899. if self._compressor:
  900. message = self._compressor.compress(message)
  901. flags |= self.RSV1
  902. # For historical reasons, write methods in Tornado operate in a semi-synchronous
  903. # mode in which awaiting the Future they return is optional (But errors can
  904. # still be raised). This requires us to go through an awkward dance here
  905. # to transform the errors that may be returned while presenting the same
  906. # semi-synchronous interface.
  907. try:
  908. fut = self._write_frame(True, opcode, message, flags=flags)
  909. except StreamClosedError:
  910. raise WebSocketClosedError()
  911. async def wrapper() -> None:
  912. try:
  913. await fut
  914. except StreamClosedError:
  915. raise WebSocketClosedError()
  916. return asyncio.ensure_future(wrapper())
    def write_ping(self, data: bytes) -> None:
        """Send ping frame.

        ``data`` must be ``bytes``; it is written as the payload of a
        final frame with opcode 0x9.
        """
        assert isinstance(data, bytes)
        self._write_frame(True, 0x9, data)
    async def _receive_frame_loop(self) -> None:
        """Receive frames until the client terminates or the stream closes.

        Afterwards the handler is notified with the recorded close code
        and reason.
        """
        try:
            while not self.client_terminated:
                await self._receive_frame()
        except StreamClosedError:
            # The stream went away underneath us: abort the connection.
            self._abort()
        self.handler.on_ws_connection_close(self.close_code, self.close_reason)
    async def _read_bytes(self, n: int) -> bytes:
        """Read ``n`` bytes from the stream and add ``n`` to the wire-byte count."""
        data = await self.stream.read_bytes(n)
        self._wire_bytes_in += n
        return data
  932. async def _receive_frame(self) -> None:
  933. # Read the frame header.
  934. data = await self._read_bytes(2)
  935. header, mask_payloadlen = struct.unpack("BB", data)
  936. is_final_frame = header & self.FIN
  937. reserved_bits = header & self.RSV_MASK
  938. opcode = header & self.OPCODE_MASK
  939. opcode_is_control = opcode & 0x8
  940. if self._decompressor is not None and opcode != 0:
  941. # Compression flag is present in the first frame's header,
  942. # but we can't decompress until we have all the frames of
  943. # the message.
  944. self._frame_compressed = bool(reserved_bits & self.RSV1)
  945. reserved_bits &= ~self.RSV1
  946. if reserved_bits:
  947. # client is using as-yet-undefined extensions; abort
  948. self._abort()
  949. return
  950. is_masked = bool(mask_payloadlen & 0x80)
  951. payloadlen = mask_payloadlen & 0x7F
  952. # Parse and validate the length.
  953. if opcode_is_control and payloadlen >= 126:
  954. # control frames must have payload < 126
  955. self._abort()
  956. return
  957. if payloadlen < 126:
  958. self._frame_length = payloadlen
  959. elif payloadlen == 126:
  960. data = await self._read_bytes(2)
  961. payloadlen = struct.unpack("!H", data)[0]
  962. elif payloadlen == 127:
  963. data = await self._read_bytes(8)
  964. payloadlen = struct.unpack("!Q", data)[0]
  965. new_len = payloadlen
  966. if self._fragmented_message_buffer is not None:
  967. new_len += len(self._fragmented_message_buffer)
  968. if new_len > self.params.max_message_size:
  969. self.close(1009, "message too big")
  970. self._abort()
  971. return
  972. # Read the payload, unmasking if necessary.
  973. if is_masked:
  974. self._frame_mask = await self._read_bytes(4)
  975. data = await self._read_bytes(payloadlen)
  976. if is_masked:
  977. assert self._frame_mask is not None
  978. data = _websocket_mask(self._frame_mask, data)
  979. # Decide what to do with this frame.
  980. if opcode_is_control:
  981. # control frames may be interleaved with a series of fragmented
  982. # data frames, so control frames must not interact with
  983. # self._fragmented_*
  984. if not is_final_frame:
  985. # control frames must not be fragmented
  986. self._abort()
  987. return
  988. elif opcode == 0: # continuation frame
  989. if self._fragmented_message_buffer is None:
  990. # nothing to continue
  991. self._abort()
  992. return
  993. self._fragmented_message_buffer.extend(data)
  994. if is_final_frame:
  995. opcode = self._fragmented_message_opcode
  996. data = bytes(self._fragmented_message_buffer)
  997. self._fragmented_message_buffer = None
  998. else: # start of new data message
  999. if self._fragmented_message_buffer is not None:
  1000. # can't start new message until the old one is finished
  1001. self._abort()
  1002. return
  1003. if not is_final_frame:
  1004. self._fragmented_message_opcode = opcode
  1005. self._fragmented_message_buffer = bytearray(data)
  1006. if is_final_frame:
  1007. handled_future = self._handle_message(opcode, data)
  1008. if handled_future is not None:
  1009. await handled_future
    def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]":
        """Execute on_message, returning its Future if it is a coroutine.

        Also handles close, ping and pong frames.  Returns ``None`` when
        the frame is ignored, rejected, or is a close or ping frame.
        """
        if self.client_terminated:
            return None

        if self._frame_compressed:
            assert self._decompressor is not None
            try:
                data = self._decompressor.decompress(data)
            except _DecompressTooLargeError:
                self.close(1009, "message too big after decompression")
                self._abort()
                return None

        if opcode == 0x1:
            # UTF-8 data
            self._message_bytes_in += len(data)
            try:
                decoded = data.decode("utf-8")
            except UnicodeDecodeError:
                # Text messages must be valid UTF-8.
                self._abort()
                return None
            return self._run_callback(self.handler.on_message, decoded)
        elif opcode == 0x2:
            # Binary data
            self._message_bytes_in += len(data)
            return self._run_callback(self.handler.on_message, data)
        elif opcode == 0x8:
            # Close
            self.client_terminated = True
            # The payload, when present, is a 2-byte big-endian code
            # optionally followed by a reason.
            if len(data) >= 2:
                self.close_code = struct.unpack(">H", data[:2])[0]
            if len(data) > 2:
                self.close_reason = to_unicode(data[2:])
            # Echo the received close code, if any (RFC 6455 section 5.5.1).
            self.close(self.close_code)
        elif opcode == 0x9:
            # Ping
            try:
                # Answer with a pong (opcode 0xA) carrying the same payload.
                self._write_frame(True, 0xA, data)
            except StreamClosedError:
                self._abort()
            self._run_callback(self.handler.on_ping, data)
        elif opcode == 0xA:
            # Pong
            self._received_pong = True
            return self._run_callback(self.handler.on_pong, data)
        else:
            # Unknown opcode.
            self._abort()
        return None
    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Closes the WebSocket connection.

        Sends a close frame unless one was already sent, then either
        closes the stream (when the peer has already closed) or schedules
        an abort in case the peer does not answer.  Any running ping task
        is cancelled.
        """
        if not self.server_terminated:
            if not self.stream.closed():
                if code is None and reason is not None:
                    code = 1000  # "normal closure" status code
                # Close payload: optional 2-byte code followed by the reason.
                if code is None:
                    close_data = b""
                else:
                    close_data = struct.pack(">H", code)
                if reason is not None:
                    close_data += utf8(reason)
                try:
                    self._write_frame(True, 0x8, close_data)
                except StreamClosedError:
                    self._abort()
            self.server_terminated = True
        if self.client_terminated:
            # Both sides have closed: cancel any pending abort timeout and
            # close the stream right away.
            if self._waiting is not None:
                self.stream.io_loop.remove_timeout(self._waiting)
                self._waiting = None
            self.stream.close()
        elif self._waiting is None:
            # Give the client a few seconds to complete a clean shutdown,
            # otherwise just close the connection.
            self._waiting = self.stream.io_loop.add_timeout(
                self.stream.io_loop.time() + 5, self._abort
            )
        if self._ping_coroutine:
            self._ping_coroutine.cancel()
            self._ping_coroutine = None
    def is_closing(self) -> bool:
        """Return ``True`` if this connection is closing.

        The connection is considered closing if either side has
        initiated its closing handshake or if the stream has been
        shut down uncleanly.
        """
        # Any one of the three conditions is sufficient.
        return self.stream.closed() or self.client_terminated or self.server_terminated
    def set_nodelay(self, x: bool) -> None:
        """Forward the no-delay setting ``x`` to the underlying stream."""
        self.stream.set_nodelay(x)
  1098. @property
  1099. def ping_interval(self) -> float:
  1100. interval = self.params.ping_interval
  1101. if interval is not None:
  1102. return interval
  1103. return 0
  1104. @property
  1105. def ping_timeout(self) -> float:
  1106. timeout = self.params.ping_timeout
  1107. if timeout is not None:
  1108. if self.ping_interval and timeout > self.ping_interval:
  1109. de_dupe_gen_log(
  1110. # Note: using de_dupe_gen_log to prevent this message from
  1111. # being duplicated for each connection
  1112. logging.WARNING,
  1113. f"The websocket_ping_timeout ({timeout}) cannot be longer"
  1114. f" than the websocket_ping_interval ({self.ping_interval})."
  1115. f"\nSetting websocket_ping_timeout={self.ping_interval}",
  1116. )
  1117. return self.ping_interval
  1118. return timeout
  1119. return self.ping_interval
  1120. def start_pinging(self) -> None:
  1121. """Start sending periodic pings to keep the connection alive"""
  1122. if (
  1123. # prevent multiple ping coroutines being run in parallel
  1124. not self._ping_coroutine
  1125. # only run the ping coroutine if a ping interval is configured
  1126. and self.ping_interval > 0
  1127. ):
  1128. self._ping_coroutine = asyncio.create_task(self.periodic_ping())
  1129. @staticmethod
  1130. def ping_sleep_time(*, last_ping_time: float, interval: float, now: float) -> float:
  1131. """Calculate the sleep time until the next ping should be sent."""
  1132. return max(0, last_ping_time + interval - now)
    async def periodic_ping(self) -> None:
        """Send a ping and wait for a pong if ping_timeout is configured.

        Called periodically if the websocket_ping_interval is set and non-zero.
        """
        # Read both settings once; they are not re-read inside the loop.
        interval = self.ping_interval
        timeout = self.ping_timeout

        await asyncio.sleep(interval)

        while True:
            # send a ping
            self._received_pong = False
            ping_time = IOLoop.current().time()
            self.write_ping(b"")

            # wait until the ping timeout
            await asyncio.sleep(timeout)

            # make sure we received a pong within the timeout
            if timeout > 0 and not self._received_pong:
                self.close(reason="ping timed out")
                return

            # wait until the next scheduled ping
            await asyncio.sleep(
                self.ping_sleep_time(
                    last_ping_time=ping_time,
                    interval=interval,
                    now=IOLoop.current().time(),
                )
            )
  1159. class WebSocketClientConnection(simple_httpclient._HTTPConnection):
  1160. """WebSocket client connection.
  1161. This class should not be instantiated directly; use the
  1162. `websocket_connect` function instead.
  1163. """
  1164. protocol = None # type: WebSocketProtocol
    def __init__(
        self,
        request: httpclient.HTTPRequest,
        on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
        compression_options: Optional[Dict[str, Any]] = None,
        ping_interval: Optional[float] = None,
        ping_timeout: Optional[float] = None,
        max_message_size: int = _default_max_message_size,
        subprotocols: Optional[List[str]] = None,
        resolver: Optional[Resolver] = None,
    ) -> None:
        """Prepare ``request`` for the websocket handshake and start it.

        The request object is modified in place: its ``ws``/``wss`` URL
        scheme is rewritten to ``http``/``https``, the upgrade headers are
        added and redirects are disabled.  A URL with any other scheme
        raises `KeyError`.
        """
        self.connect_future = Future()  # type: Future[WebSocketClientConnection]
        self.read_queue = Queue(1)  # type: Queue[Union[None, str, bytes]]
        # Random handshake challenge key sent as Sec-WebSocket-Key.
        self.key = base64.b64encode(os.urandom(16))
        self._on_message_callback = on_message_callback
        self.close_code = None  # type: Optional[int]
        self.close_reason = None  # type: Optional[str]
        self.params = _WebSocketParams(
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            max_message_size=max_message_size,
            compression_options=compression_options,
        )

        # Map the websocket scheme onto the matching HTTP scheme.
        scheme, sep, rest = request.url.partition(":")
        scheme = {"ws": "http", "wss": "https"}[scheme]
        request.url = scheme + sep + rest
        request.headers.update(
            {
                "Upgrade": "websocket",
                "Connection": "Upgrade",
                "Sec-WebSocket-Key": to_unicode(self.key),
                "Sec-WebSocket-Version": "13",
            }
        )
        if subprotocols is not None:
            request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
        if compression_options is not None:
            # Always offer to let the server set our max_wbits (and even though
            # we don't offer it, we will accept a client_no_context_takeover
            # from the server).
            # TODO: set server parameters for deflate extension
            # if requested in self.compression_options.
            request.headers["Sec-WebSocket-Extensions"] = (
                "permessage-deflate; client_max_window_bits"
            )

        # Websocket connection is currently unable to follow redirects
        request.follow_redirects = False

        self.tcp_client = TCPClient(resolver=resolver)
        # NOTE(review): the positional arguments below follow the base
        # class's constructor signature, which is not visible here; confirm
        # the meaning of the numeric limits against simple_httpclient.
        super().__init__(
            None,
            request,
            lambda: None,
            self._on_http_response,
            104857600,
            self.tcp_client,
            65536,
            104857600,
        )
  1223. def __del__(self) -> None:
  1224. if self.protocol is not None:
  1225. # Unclosed client connections can sometimes log "task was destroyed but
  1226. # was pending" warnings if shutdown strikes at the wrong time (such as
  1227. # while a ping is being processed due to ping_interval). Log our own
  1228. # warning to make it a little more deterministic (although it's still
  1229. # dependent on GC timing).
  1230. warnings.warn("Unclosed WebSocketClientConnection", ResourceWarning)
  1231. def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
  1232. """Closes the websocket connection.
  1233. ``code`` and ``reason`` are documented under
  1234. `WebSocketHandler.close`.
  1235. .. versionadded:: 3.2
  1236. .. versionchanged:: 4.0
  1237. Added the ``code`` and ``reason`` arguments.
  1238. """
  1239. if self.protocol is not None:
  1240. self.protocol.close(code, reason)
  1241. self.protocol = None # type: ignore
    def on_connection_close(self) -> None:
        """Handle the connection closing.

        Fails a still-pending connect future, delivers a ``None`` message
        to signal the close, and closes the TCP client.
        """
        if not self.connect_future.done():
            self.connect_future.set_exception(StreamClosedError())
        # A message of None tells the reader that the connection is closed.
        self._on_message(None)
        self.tcp_client.close()
        super().on_connection_close()
    def on_ws_connection_close(
        self, close_code: Optional[int] = None, close_reason: Optional[str] = None
    ) -> None:
        """Record the close code and reason, then run the close handling."""
        self.close_code = close_code
        self.close_reason = close_reason
        self.on_connection_close()
  1254. def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
  1255. if not self.connect_future.done():
  1256. if response.error:
  1257. self.connect_future.set_exception(response.error)
  1258. else:
  1259. self.connect_future.set_exception(
  1260. WebSocketError("Non-websocket response")
  1261. )
    async def headers_received(
        self,
        start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
        headers: httputil.HTTPHeaders,
    ) -> None:
        """Take over the connection when the server answers with 101.

        Any other status code is passed on to the base class's handling.
        """
        assert isinstance(start_line, httputil.ResponseStartLine)
        if start_line.code != 101:
            await super().headers_received(start_line, headers)
            return

        # The handshake succeeded: cancel the pending timeout, if any.
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

        self.headers = headers
        self.protocol = self.get_websocket_protocol()
        self.protocol._process_server_headers(self.key, self.headers)
        # The stream is detached from the HTTP connection and handed to
        # the websocket protocol.
        self.protocol.stream = self.connection.detach()

        IOLoop.current().add_callback(self.protocol._receive_frame_loop)
        self.protocol.start_pinging()

        # Once we've taken over the connection, clear the final callback
        # we set on the http request.  This deactivates the error handling
        # in simple_httpclient that would otherwise interfere with our
        # ability to see exceptions.
        self.final_callback = None  # type: ignore

        future_set_result_unless_cancelled(self.connect_future, self)
  1286. def write_message(
  1287. self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
  1288. ) -> "Future[None]":
  1289. """Sends a message to the WebSocket server.
  1290. If the stream is closed, raises `WebSocketClosedError`.
  1291. Returns a `.Future` which can be used for flow control.
  1292. .. versionchanged:: 5.0
  1293. Exception raised on a closed stream changed from `.StreamClosedError`
  1294. to `WebSocketClosedError`.
  1295. """
  1296. if self.protocol is None:
  1297. raise WebSocketClosedError("Client connection has been closed")
  1298. return self.protocol.write_message(message, binary=binary)
    def read_message(
        self,
        callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None,
    ) -> Awaitable[Union[None, str, bytes]]:
        """Reads a message from the WebSocket server.

        If on_message_callback was specified at WebSocket
        initialization, this function will never return messages.

        Returns a future whose result is the message, or None
        if the connection is closed.  If a callback argument
        is given it will be called with the future when it is
        ready.
        """
        # Messages are taken from the queue that _on_message fills.
        awaitable = self.read_queue.get()
        if callback is not None:
            self.io_loop.add_future(asyncio.ensure_future(awaitable), callback)
        return awaitable
    def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
        """Deliver a received ``message``; delegates to `_on_message`."""
        return self._on_message(message)
  1317. def _on_message(
  1318. self, message: Union[None, str, bytes]
  1319. ) -> Optional[Awaitable[None]]:
  1320. if self._on_message_callback:
  1321. self._on_message_callback(message)
  1322. return None
  1323. else:
  1324. return self.read_queue.put(message)
    def ping(self, data: bytes = b"") -> None:
        """Send ping frame to the remote end.

        The data argument allows a small amount of data (up to 125
        bytes) to be sent as a part of the ping message. Note that not
        all websocket implementations expose this data to
        applications.

        Consider using the ``ping_interval`` argument to
        `websocket_connect` instead of sending pings manually.

        Raises `WebSocketClosedError` if the connection has been closed.

        .. versionadded:: 5.1

        """
        # ``data`` is converted before the closed-connection check, so a
        # conversion error takes precedence over WebSocketClosedError.
        data = utf8(data)
        if self.protocol is None:
            raise WebSocketClosedError()
        self.protocol.write_ping(data)
    def on_pong(self, data: bytes) -> None:
        """Called with the payload of a received pong; does nothing here."""
        pass
    def on_ping(self, data: bytes) -> None:
        """Called with the payload of a received ping; does nothing here."""
        pass
    def get_websocket_protocol(self) -> WebSocketProtocol:
        """Create the protocol object for this client connection.

        Outgoing frames are masked (``mask_outgoing=True``).
        """
        return WebSocketProtocol13(self, mask_outgoing=True, params=self.params)
    @property
    def selected_subprotocol(self) -> Optional[str]:
        """The subprotocol selected by the server.

        Read from ``self.protocol``, so it is only available while a
        protocol object is attached to this connection.

        .. versionadded:: 5.1
        """
        return self.protocol.selected_subprotocol
    def log_exception(
        self,
        typ: "Optional[Type[BaseException]]",
        value: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Log an uncaught exception to the application log.

        ``typ`` and ``value`` must not be ``None``.
        """
        assert typ is not None
        assert value is not None
        app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))
  def websocket_connect(
      url: Union[str, httpclient.HTTPRequest],
      callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None,
      connect_timeout: Optional[float] = None,
      on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
      compression_options: Optional[Dict[str, Any]] = None,
      ping_interval: Optional[float] = None,
      ping_timeout: Optional[float] = None,
      max_message_size: int = _default_max_message_size,
      subprotocols: Optional[List[str]] = None,
      resolver: Optional[Resolver] = None,
  ) -> "Awaitable[WebSocketClientConnection]":
      """Client-side websocket support.

      Takes a url and returns a Future whose result is a
      `WebSocketClientConnection`.

      ``compression_options`` is interpreted in the same way as the
      return value of `.WebSocketHandler.get_compression_options`.

      The connection supports two styles of operation. In the coroutine
      style, the application typically calls
      `~.WebSocketClientConnection.read_message` in a loop::

          conn = yield websocket_connect(url)
          while True:
              msg = yield conn.read_message()
              if msg is None: break
              # Do something with msg

      In the callback style, pass an ``on_message_callback`` to
      ``websocket_connect``. In both styles, a message of ``None``
      indicates that the connection has been closed.

      ``subprotocols`` may be a list of strings specifying proposed
      subprotocols. The selected protocol may be found on the
      ``selected_subprotocol`` attribute of the connection object
      when the connection is complete.

      .. versionchanged:: 3.2
         Also accepts ``HTTPRequest`` objects in place of urls.

      .. versionchanged:: 4.1
         Added ``compression_options`` and ``on_message_callback``.

      .. versionchanged:: 4.5
         Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size``
         arguments, which have the same meaning as in `WebSocketHandler`.

      .. versionchanged:: 5.0
         The ``io_loop`` argument (deprecated since version 4.1) has been removed.

      .. versionchanged:: 5.1
         Added the ``subprotocols`` argument.

      .. versionchanged:: 6.3
         Added the ``resolver`` argument.

      .. deprecated:: 6.5
         The ``callback`` argument is deprecated and will be removed in Tornado 7.0.
         Use the returned Future instead. Note that ``on_message_callback`` is not
         deprecated and may still be used.
      """
      if not isinstance(url, httpclient.HTTPRequest):
          request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
      else:
          # ``connect_timeout`` must not be combined with a prebuilt request.
          assert connect_timeout is None
          request = url
          # Copy and convert the headers dict/object (see comments in
          # AsyncHTTPClient.fetch)
          request.headers = httputil.HTTPHeaders(request.headers)

      proxied_request = cast(
          httpclient.HTTPRequest,
          httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS),
      )
      connection = WebSocketClientConnection(
          proxied_request,
          on_message_callback=on_message_callback,
          compression_options=compression_options,
          ping_interval=ping_interval,
          ping_timeout=ping_timeout,
          max_message_size=max_message_size,
          subprotocols=subprotocols,
          resolver=resolver,
      )
      connect_future = connection.connect_future

      if callback is not None:
          warnings.warn(
              "The callback argument to websocket_connect is deprecated. "
              "Use the returned Future instead.",
              DeprecationWarning,
              stacklevel=2,
          )
          IOLoop.current().add_future(connect_future, callback)
      return connect_future