websocket_test.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989
  1. import asyncio
  2. import contextlib
  3. import datetime
  4. import functools
  5. import socket
  6. import traceback
  7. import typing
  8. import unittest
  9. from tornado.concurrent import Future
  10. from tornado import gen
  11. from tornado.httpclient import HTTPError, HTTPRequest
  12. from tornado.locks import Event
  13. from tornado.log import gen_log, app_log
  14. from tornado.netutil import Resolver
  15. from tornado.simple_httpclient import SimpleAsyncHTTPClient
  16. from tornado.template import DictLoader
  17. from tornado.test.util import abstract_base_test, ignore_deprecation
  18. from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
  19. from tornado.web import Application, RequestHandler
  20. try:
  21. import tornado.websocket # noqa: F401
  22. from tornado.util import _websocket_mask_python
  23. except ImportError:
  24. # The unittest module presents misleading errors on ImportError
  25. # (it acts as if websocket_test could not be found, hiding the underlying
  26. # error). If we get an ImportError here (which could happen due to
  27. # TORNADO_EXTENSION=1), print some extra information before failing.
  28. traceback.print_exc()
  29. raise
  30. from tornado.websocket import (
  31. WebSocketHandler,
  32. websocket_connect,
  33. WebSocketError,
  34. WebSocketClosedError,
  35. )
  36. try:
  37. from tornado import speedups
  38. except ImportError:
  39. speedups = None # type: ignore
  40. class TestWebSocketHandler(WebSocketHandler):
  41. """Base class for testing handlers that exposes the on_close event.
  42. This allows for tests to see the close code and reason on the
  43. server side.
  44. """
  45. def initialize(self, close_future=None, compression_options=None):
  46. self.close_future = close_future
  47. self.compression_options = compression_options
  48. def get_compression_options(self):
  49. return self.compression_options
  50. def on_close(self):
  51. if self.close_future is not None:
  52. self.close_future.set_result((self.close_code, self.close_reason))
  53. class EchoHandler(TestWebSocketHandler):
  54. @gen.coroutine
  55. def on_message(self, message):
  56. try:
  57. yield self.write_message(message, isinstance(message, bytes))
  58. except asyncio.CancelledError:
  59. pass
  60. except WebSocketClosedError:
  61. pass
  62. class ErrorInOnMessageHandler(TestWebSocketHandler):
  63. def on_message(self, message):
  64. 1 / 0
  65. class HeaderHandler(TestWebSocketHandler):
  66. def open(self):
  67. methods_to_test = [
  68. functools.partial(self.write, "This should not work"),
  69. functools.partial(self.redirect, "http://localhost/elsewhere"),
  70. functools.partial(self.set_header, "X-Test", ""),
  71. functools.partial(self.set_cookie, "Chocolate", "Chip"),
  72. functools.partial(self.set_status, 503),
  73. self.flush,
  74. self.finish,
  75. ]
  76. for method in methods_to_test:
  77. try:
  78. # In a websocket context, many RequestHandler methods
  79. # raise RuntimeErrors.
  80. method() # type: ignore
  81. raise Exception("did not get expected exception")
  82. except RuntimeError:
  83. pass
  84. self.write_message(self.request.headers.get("X-Test", ""))
  85. class HeaderEchoHandler(TestWebSocketHandler):
  86. def set_default_headers(self):
  87. self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
  88. def prepare(self):
  89. for k, v in self.request.headers.get_all():
  90. if k.lower().startswith("x-test"):
  91. self.set_header(k, v)
  92. class NonWebSocketHandler(RequestHandler):
  93. def get(self):
  94. self.write("ok")
  95. class RedirectHandler(RequestHandler):
  96. def get(self):
  97. self.redirect("/echo")
  98. class CloseReasonHandler(TestWebSocketHandler):
  99. def open(self):
  100. self.on_close_called = False
  101. self.close(1001, "goodbye")
  102. class AsyncPrepareHandler(TestWebSocketHandler):
  103. @gen.coroutine
  104. def prepare(self):
  105. yield gen.moment
  106. def on_message(self, message):
  107. self.write_message(message)
  108. class PathArgsHandler(TestWebSocketHandler):
  109. def open(self, arg):
  110. self.write_message(arg)
  111. class CoroutineOnMessageHandler(TestWebSocketHandler):
  112. def initialize(self, **kwargs):
  113. super().initialize(**kwargs)
  114. self.sleeping = 0
  115. @gen.coroutine
  116. def on_message(self, message):
  117. if self.sleeping > 0:
  118. self.write_message("another coroutine is already sleeping")
  119. self.sleeping += 1
  120. yield gen.sleep(0.01)
  121. self.sleeping -= 1
  122. self.write_message(message)
  123. class RenderMessageHandler(TestWebSocketHandler):
  124. def on_message(self, message):
  125. self.write_message(self.render_string("message.html", message=message))
  126. class SubprotocolHandler(TestWebSocketHandler):
  127. def initialize(self, **kwargs):
  128. super().initialize(**kwargs)
  129. self.select_subprotocol_called = False
  130. def select_subprotocol(self, subprotocols):
  131. if self.select_subprotocol_called:
  132. raise Exception("select_subprotocol called twice")
  133. self.select_subprotocol_called = True
  134. if "goodproto" in subprotocols:
  135. return "goodproto"
  136. return None
  137. def open(self):
  138. if not self.select_subprotocol_called:
  139. raise Exception("select_subprotocol not called")
  140. self.write_message("subprotocol=%s" % self.selected_subprotocol)
  141. class OpenCoroutineHandler(TestWebSocketHandler):
  142. def initialize(self, test, **kwargs):
  143. super().initialize(**kwargs)
  144. self.test = test
  145. self.open_finished = False
  146. @gen.coroutine
  147. def open(self):
  148. yield self.test.message_sent.wait()
  149. yield gen.sleep(0.010)
  150. self.open_finished = True
  151. def on_message(self, message):
  152. if not self.open_finished:
  153. raise Exception("on_message called before open finished")
  154. self.write_message("ok")
  155. class ErrorInOpenHandler(TestWebSocketHandler):
  156. def open(self):
  157. raise Exception("boom")
  158. class ErrorInAsyncOpenHandler(TestWebSocketHandler):
  159. async def open(self):
  160. await asyncio.sleep(0)
  161. raise Exception("boom")
  162. class NoDelayHandler(TestWebSocketHandler):
  163. def open(self):
  164. self.set_nodelay(True)
  165. self.write_message("hello")
  166. class WebSocketBaseTestCase(AsyncHTTPTestCase):
  167. def setUp(self):
  168. super().setUp()
  169. self.conns_to_close = []
  170. def tearDown(self):
  171. for conn in self.conns_to_close:
  172. conn.close()
  173. super().tearDown()
  174. @gen.coroutine
  175. def ws_connect(self, path, **kwargs):
  176. ws = yield websocket_connect(
  177. "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
  178. )
  179. self.conns_to_close.append(ws)
  180. raise gen.Return(ws)
  181. class WebSocketTest(WebSocketBaseTestCase):
  182. def get_app(self):
  183. self.close_future = Future() # type: Future[None]
  184. return Application(
  185. [
  186. ("/echo", EchoHandler, dict(close_future=self.close_future)),
  187. ("/non_ws", NonWebSocketHandler),
  188. ("/redirect", RedirectHandler),
  189. ("/header", HeaderHandler, dict(close_future=self.close_future)),
  190. (
  191. "/header_echo",
  192. HeaderEchoHandler,
  193. dict(close_future=self.close_future),
  194. ),
  195. (
  196. "/close_reason",
  197. CloseReasonHandler,
  198. dict(close_future=self.close_future),
  199. ),
  200. (
  201. "/error_in_on_message",
  202. ErrorInOnMessageHandler,
  203. dict(close_future=self.close_future),
  204. ),
  205. (
  206. "/async_prepare",
  207. AsyncPrepareHandler,
  208. dict(close_future=self.close_future),
  209. ),
  210. (
  211. "/path_args/(.*)",
  212. PathArgsHandler,
  213. dict(close_future=self.close_future),
  214. ),
  215. (
  216. "/coroutine",
  217. CoroutineOnMessageHandler,
  218. dict(close_future=self.close_future),
  219. ),
  220. ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
  221. (
  222. "/subprotocol",
  223. SubprotocolHandler,
  224. dict(close_future=self.close_future),
  225. ),
  226. (
  227. "/open_coroutine",
  228. OpenCoroutineHandler,
  229. dict(close_future=self.close_future, test=self),
  230. ),
  231. ("/error_in_open", ErrorInOpenHandler),
  232. ("/error_in_async_open", ErrorInAsyncOpenHandler),
  233. ("/nodelay", NoDelayHandler),
  234. ],
  235. template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
  236. )
  237. def get_http_client(self):
  238. # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
  239. return SimpleAsyncHTTPClient()
  240. def tearDown(self):
  241. super().tearDown()
  242. RequestHandler._template_loaders.clear()
  243. def test_http_request(self):
  244. # WS server, HTTP client.
  245. response = self.fetch("/echo")
  246. self.assertEqual(response.code, 400)
  247. def test_missing_websocket_key(self):
  248. response = self.fetch(
  249. "/echo",
  250. headers={
  251. "Connection": "Upgrade",
  252. "Upgrade": "WebSocket",
  253. "Sec-WebSocket-Version": "13",
  254. },
  255. )
  256. self.assertEqual(response.code, 400)
  257. def test_bad_websocket_version(self):
  258. response = self.fetch(
  259. "/echo",
  260. headers={
  261. "Connection": "Upgrade",
  262. "Upgrade": "WebSocket",
  263. "Sec-WebSocket-Version": "12",
  264. },
  265. )
  266. self.assertEqual(response.code, 426)
  267. @gen_test
  268. def test_websocket_gen(self):
  269. ws = yield self.ws_connect("/echo")
  270. yield ws.write_message("hello")
  271. response = yield ws.read_message()
  272. self.assertEqual(response, "hello")
  273. def test_websocket_callbacks(self):
  274. with ignore_deprecation():
  275. websocket_connect(
  276. "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
  277. )
  278. ws = self.wait().result()
  279. ws.write_message("hello")
  280. ws.read_message(self.stop)
  281. response = self.wait().result()
  282. self.assertEqual(response, "hello")
  283. self.close_future.add_done_callback(lambda f: self.stop())
  284. ws.close()
  285. self.wait()
  286. @gen_test
  287. def test_binary_message(self):
  288. ws = yield self.ws_connect("/echo")
  289. ws.write_message(b"hello \xe9", binary=True)
  290. response = yield ws.read_message()
  291. self.assertEqual(response, b"hello \xe9")
  292. @gen_test
  293. def test_unicode_message(self):
  294. ws = yield self.ws_connect("/echo")
  295. ws.write_message("hello \u00e9")
  296. response = yield ws.read_message()
  297. self.assertEqual(response, "hello \u00e9")
  298. @gen_test
  299. def test_error_in_closed_client_write_message(self):
  300. ws = yield self.ws_connect("/echo")
  301. ws.close()
  302. with self.assertRaises(WebSocketClosedError):
  303. ws.write_message("hello \u00e9")
  304. @gen_test
  305. def test_render_message(self):
  306. ws = yield self.ws_connect("/render")
  307. ws.write_message("hello")
  308. response = yield ws.read_message()
  309. self.assertEqual(response, "<b>hello</b>")
  310. @gen_test
  311. def test_error_in_on_message(self):
  312. ws = yield self.ws_connect("/error_in_on_message")
  313. ws.write_message("hello")
  314. with ExpectLog(app_log, "Uncaught exception"):
  315. response = yield ws.read_message()
  316. self.assertIsNone(response)
  317. @gen_test
  318. def test_websocket_http_fail(self):
  319. with self.assertRaises(HTTPError) as cm:
  320. yield self.ws_connect("/notfound")
  321. self.assertEqual(cm.exception.code, 404)
  322. @gen_test
  323. def test_websocket_http_success(self):
  324. with self.assertRaises(WebSocketError):
  325. yield self.ws_connect("/non_ws")
  326. @gen_test
  327. def test_websocket_http_redirect(self):
  328. with self.assertRaises(HTTPError):
  329. yield self.ws_connect("/redirect")
  330. @gen_test
  331. def test_websocket_network_fail(self):
  332. sock, port = bind_unused_port()
  333. sock.close()
  334. with self.assertRaises(IOError):
  335. with ExpectLog(gen_log, ".*", required=False):
  336. yield websocket_connect(
  337. "ws://127.0.0.1:%d/" % port, connect_timeout=3600
  338. )
  339. @gen_test
  340. def test_websocket_close_buffered_data(self):
  341. with contextlib.closing(
  342. (yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()))
  343. ) as ws:
  344. ws.write_message("hello")
  345. ws.write_message("world")
  346. # Close the underlying stream.
  347. ws.stream.close()
  348. @gen_test
  349. def test_websocket_headers(self):
  350. # Ensure that arbitrary headers can be passed through websocket_connect.
  351. with contextlib.closing(
  352. (
  353. yield websocket_connect(
  354. HTTPRequest(
  355. "ws://127.0.0.1:%d/header" % self.get_http_port(),
  356. headers={"X-Test": "hello"},
  357. )
  358. )
  359. )
  360. ) as ws:
  361. response = yield ws.read_message()
  362. self.assertEqual(response, "hello")
  363. @gen_test
  364. def test_websocket_header_echo(self):
  365. # Ensure that headers can be returned in the response.
  366. # Specifically, that arbitrary headers passed through websocket_connect
  367. # can be returned.
  368. with contextlib.closing(
  369. (
  370. yield websocket_connect(
  371. HTTPRequest(
  372. "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
  373. headers={"X-Test-Hello": "hello"},
  374. )
  375. )
  376. )
  377. ) as ws:
  378. self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
  379. self.assertEqual(
  380. ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
  381. )
  382. @gen_test
  383. def test_server_close_reason(self):
  384. ws = yield self.ws_connect("/close_reason")
  385. msg = yield ws.read_message()
  386. # A message of None means the other side closed the connection.
  387. self.assertIs(msg, None)
  388. self.assertEqual(ws.close_code, 1001)
  389. self.assertEqual(ws.close_reason, "goodbye")
  390. # The on_close callback is called no matter which side closed.
  391. code, reason = yield self.close_future
  392. # The client echoed the close code it received to the server,
  393. # so the server's close code (returned via close_future) is
  394. # the same.
  395. self.assertEqual(code, 1001)
  396. @gen_test
  397. def test_client_close_reason(self):
  398. ws = yield self.ws_connect("/echo")
  399. ws.close(1001, "goodbye")
  400. code, reason = yield self.close_future
  401. self.assertEqual(code, 1001)
  402. self.assertEqual(reason, "goodbye")
  403. @gen_test
  404. def test_write_after_close(self):
  405. ws = yield self.ws_connect("/close_reason")
  406. msg = yield ws.read_message()
  407. self.assertIs(msg, None)
  408. with self.assertRaises(WebSocketClosedError):
  409. ws.write_message("hello")
  410. @gen_test
  411. def test_async_prepare(self):
  412. # Previously, an async prepare method triggered a bug that would
  413. # result in a timeout on test shutdown (and a memory leak).
  414. ws = yield self.ws_connect("/async_prepare")
  415. ws.write_message("hello")
  416. res = yield ws.read_message()
  417. self.assertEqual(res, "hello")
  418. @gen_test
  419. def test_path_args(self):
  420. ws = yield self.ws_connect("/path_args/hello")
  421. res = yield ws.read_message()
  422. self.assertEqual(res, "hello")
  423. @gen_test
  424. def test_coroutine(self):
  425. ws = yield self.ws_connect("/coroutine")
  426. # Send both messages immediately, coroutine must process one at a time.
  427. yield ws.write_message("hello1")
  428. yield ws.write_message("hello2")
  429. res = yield ws.read_message()
  430. self.assertEqual(res, "hello1")
  431. res = yield ws.read_message()
  432. self.assertEqual(res, "hello2")
  433. @gen_test
  434. def test_check_origin_valid_no_path(self):
  435. port = self.get_http_port()
  436. url = "ws://127.0.0.1:%d/echo" % port
  437. headers = {"Origin": "http://127.0.0.1:%d" % port}
  438. with contextlib.closing(
  439. (yield websocket_connect(HTTPRequest(url, headers=headers)))
  440. ) as ws:
  441. ws.write_message("hello")
  442. response = yield ws.read_message()
  443. self.assertEqual(response, "hello")
  444. @gen_test
  445. def test_check_origin_valid_with_path(self):
  446. port = self.get_http_port()
  447. url = "ws://127.0.0.1:%d/echo" % port
  448. headers = {"Origin": "http://127.0.0.1:%d/something" % port}
  449. with contextlib.closing(
  450. (yield websocket_connect(HTTPRequest(url, headers=headers)))
  451. ) as ws:
  452. ws.write_message("hello")
  453. response = yield ws.read_message()
  454. self.assertEqual(response, "hello")
  455. @gen_test
  456. def test_check_origin_invalid_partial_url(self):
  457. port = self.get_http_port()
  458. url = "ws://127.0.0.1:%d/echo" % port
  459. headers = {"Origin": "127.0.0.1:%d" % port}
  460. with self.assertRaises(HTTPError) as cm:
  461. yield websocket_connect(HTTPRequest(url, headers=headers))
  462. self.assertEqual(cm.exception.code, 403)
  463. @gen_test
  464. def test_check_origin_invalid(self):
  465. port = self.get_http_port()
  466. url = "ws://127.0.0.1:%d/echo" % port
  467. # Host is 127.0.0.1, which should not be accessible from some other
  468. # domain
  469. headers = {"Origin": "http://somewhereelse.com"}
  470. with self.assertRaises(HTTPError) as cm:
  471. yield websocket_connect(HTTPRequest(url, headers=headers))
  472. self.assertEqual(cm.exception.code, 403)
  473. @gen_test
  474. def test_check_origin_invalid_subdomains(self):
  475. port = self.get_http_port()
  476. # CaresResolver may return ipv6-only results for localhost, but our
  477. # server is only running on ipv4. Test for this edge case and skip
  478. # the test if it happens.
  479. addrinfo = yield Resolver().resolve("localhost", port)
  480. families = {addr[0] for addr in addrinfo}
  481. if socket.AF_INET not in families:
  482. self.skipTest("localhost does not resolve to ipv4")
  483. return
  484. url = "ws://localhost:%d/echo" % port
  485. # Subdomains should be disallowed by default. If we could pass a
  486. # resolver to websocket_connect we could test sibling domains as well.
  487. headers = {"Origin": "http://subtenant.localhost"}
  488. with self.assertRaises(HTTPError) as cm:
  489. yield websocket_connect(HTTPRequest(url, headers=headers))
  490. self.assertEqual(cm.exception.code, 403)
  491. @gen_test
  492. def test_subprotocols(self):
  493. ws = yield self.ws_connect(
  494. "/subprotocol", subprotocols=["badproto", "goodproto"]
  495. )
  496. self.assertEqual(ws.selected_subprotocol, "goodproto")
  497. res = yield ws.read_message()
  498. self.assertEqual(res, "subprotocol=goodproto")
  499. @gen_test
  500. def test_subprotocols_not_offered(self):
  501. ws = yield self.ws_connect("/subprotocol")
  502. self.assertIs(ws.selected_subprotocol, None)
  503. res = yield ws.read_message()
  504. self.assertEqual(res, "subprotocol=None")
  505. @gen_test
  506. def test_open_coroutine(self):
  507. self.message_sent = Event()
  508. ws = yield self.ws_connect("/open_coroutine")
  509. yield ws.write_message("hello")
  510. self.message_sent.set()
  511. res = yield ws.read_message()
  512. self.assertEqual(res, "ok")
  513. @gen_test
  514. def test_error_in_open(self):
  515. with ExpectLog(app_log, "Uncaught exception"):
  516. ws = yield self.ws_connect("/error_in_open")
  517. res = yield ws.read_message()
  518. self.assertIsNone(res)
  519. @gen_test
  520. def test_error_in_async_open(self):
  521. with ExpectLog(app_log, "Uncaught exception"):
  522. ws = yield self.ws_connect("/error_in_async_open")
  523. res = yield ws.read_message()
  524. self.assertIsNone(res)
  525. @gen_test
  526. def test_nodelay(self):
  527. ws = yield self.ws_connect("/nodelay")
  528. res = yield ws.read_message()
  529. self.assertEqual(res, "hello")
  530. class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
  531. def initialize(self, **kwargs):
  532. super().initialize(**kwargs)
  533. self.sleeping = 0
  534. async def on_message(self, message):
  535. if self.sleeping > 0:
  536. self.write_message("another coroutine is already sleeping")
  537. self.sleeping += 1
  538. await gen.sleep(0.01)
  539. self.sleeping -= 1
  540. self.write_message(message)
  541. class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
  542. def get_app(self):
  543. return Application([("/native", NativeCoroutineOnMessageHandler)])
  544. @gen_test
  545. def test_native_coroutine(self):
  546. ws = yield self.ws_connect("/native")
  547. # Send both messages immediately, coroutine must process one at a time.
  548. yield ws.write_message("hello1")
  549. yield ws.write_message("hello2")
  550. res = yield ws.read_message()
  551. self.assertEqual(res, "hello1")
  552. res = yield ws.read_message()
  553. self.assertEqual(res, "hello2")
  554. @abstract_base_test
  555. class CompressionTestMixin(WebSocketBaseTestCase):
  556. MESSAGE = "Hello world. Testing 123 123"
  557. def get_app(self):
  558. class LimitedHandler(TestWebSocketHandler):
  559. @property
  560. def max_message_size(self):
  561. return 1024
  562. def on_message(self, message):
  563. self.write_message(str(len(message)))
  564. return Application(
  565. [
  566. (
  567. "/echo",
  568. EchoHandler,
  569. dict(compression_options=self.get_server_compression_options()),
  570. ),
  571. (
  572. "/limited",
  573. LimitedHandler,
  574. dict(compression_options=self.get_server_compression_options()),
  575. ),
  576. ]
  577. )
  578. def get_server_compression_options(self):
  579. return None
  580. def get_client_compression_options(self):
  581. return None
  582. def verify_wire_bytes(self, bytes_in: int, bytes_out: int) -> None:
  583. raise NotImplementedError()
  584. @gen_test
  585. def test_message_sizes(self):
  586. ws = yield self.ws_connect(
  587. "/echo", compression_options=self.get_client_compression_options()
  588. )
  589. # Send the same message three times so we can measure the
  590. # effect of the context_takeover options.
  591. for i in range(3):
  592. ws.write_message(self.MESSAGE)
  593. response = yield ws.read_message()
  594. self.assertEqual(response, self.MESSAGE)
  595. self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
  596. self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
  597. self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
  598. @gen_test
  599. def test_size_limit(self):
  600. ws = yield self.ws_connect(
  601. "/limited", compression_options=self.get_client_compression_options()
  602. )
  603. # Small messages pass through.
  604. ws.write_message("a" * 128)
  605. response = yield ws.read_message()
  606. self.assertEqual(response, "128")
  607. # This message is too big after decompression, but it compresses
  608. # down to a size that will pass the initial checks.
  609. ws.write_message("a" * 2048)
  610. response = yield ws.read_message()
  611. self.assertIsNone(response)
  612. @abstract_base_test
  613. class UncompressedTestMixin(CompressionTestMixin):
  614. """Specialization of CompressionTestMixin when we expect no compression."""
  615. def verify_wire_bytes(self, bytes_in, bytes_out):
  616. # Bytes out includes the 4-byte mask key per message.
  617. self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
  618. self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
  619. class NoCompressionTest(UncompressedTestMixin):
  620. pass
  621. # If only one side tries to compress, the extension is not negotiated.
  622. class ServerOnlyCompressionTest(UncompressedTestMixin):
  623. def get_server_compression_options(self):
  624. return {}
  625. class ClientOnlyCompressionTest(UncompressedTestMixin):
  626. def get_client_compression_options(self):
  627. return {}
  628. class DefaultCompressionTest(CompressionTestMixin):
  629. def get_server_compression_options(self):
  630. return {}
  631. def get_client_compression_options(self):
  632. return {}
  633. def verify_wire_bytes(self, bytes_in, bytes_out):
  634. self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
  635. self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
  636. # Bytes out includes the 4 bytes mask key per message.
  637. self.assertEqual(bytes_out, bytes_in + 12)
  638. @abstract_base_test
  639. class MaskFunctionMixin(unittest.TestCase):
  640. # Subclasses should define self.mask(mask, data)
  641. def mask(self, mask: bytes, data: bytes) -> bytes:
  642. raise NotImplementedError()
  643. def test_mask(self: typing.Any):
  644. self.assertEqual(self.mask(b"abcd", b""), b"")
  645. self.assertEqual(self.mask(b"abcd", b"b"), b"\x03")
  646. self.assertEqual(self.mask(b"abcd", b"54321"), b"TVPVP")
  647. self.assertEqual(self.mask(b"ZXCV", b"98765432"), b"c`t`olpd")
  648. # Include test cases with \x00 bytes (to ensure that the C
  649. # extension isn't depending on null-terminated strings) and
  650. # bytes with the high bit set (to smoke out signedness issues).
  651. self.assertEqual(
  652. self.mask(b"\x00\x01\x02\x03", b"\xff\xfb\xfd\xfc\xfe\xfa"),
  653. b"\xff\xfa\xff\xff\xfe\xfb",
  654. )
  655. self.assertEqual(
  656. self.mask(b"\xff\xfb\xfd\xfc", b"\x00\x01\x02\x03\x04\x05"),
  657. b"\xff\xfa\xff\xff\xfb\xfe",
  658. )
  659. class PythonMaskFunctionTest(MaskFunctionMixin):
  660. def mask(self, mask, data):
  661. return _websocket_mask_python(mask, data)
  662. @unittest.skipIf(speedups is None, "tornado.speedups module not present")
  663. class CythonMaskFunctionTest(MaskFunctionMixin):
  664. def mask(self, mask, data):
  665. return speedups.websocket_mask(mask, data)
  666. class ServerPeriodicPingTest(WebSocketBaseTestCase):
  667. def get_app(self):
  668. class PingHandler(TestWebSocketHandler):
  669. def on_pong(self, data):
  670. self.write_message("got pong")
  671. return Application(
  672. [("/", PingHandler)],
  673. websocket_ping_interval=0.01,
  674. websocket_ping_timeout=0,
  675. )
  676. @gen_test
  677. def test_server_ping(self):
  678. ws = yield self.ws_connect("/")
  679. for i in range(3):
  680. response = yield ws.read_message()
  681. self.assertEqual(response, "got pong")
  682. # TODO: test that the connection gets closed if ping responses stop.
  683. class ClientPeriodicPingTest(WebSocketBaseTestCase):
  684. def get_app(self):
  685. class PingHandler(TestWebSocketHandler):
  686. def on_ping(self, data):
  687. self.write_message("got ping")
  688. return Application([("/", PingHandler)])
  689. @gen_test
  690. def test_client_ping(self):
  691. ws = yield self.ws_connect("/", ping_interval=0.01, ping_timeout=0)
  692. for i in range(3):
  693. response = yield ws.read_message()
  694. self.assertEqual(response, "got ping")
  695. ws.close()
  696. class ServerPingTimeoutTest(WebSocketBaseTestCase):
  697. def get_app(self):
  698. self.handlers: list[WebSocketHandler] = []
  699. test = self
  700. class PingHandler(TestWebSocketHandler):
  701. def initialize(self, close_future=None, compression_options=None):
  702. self.handlers = test.handlers
  703. # capture the handler instance so we can interrogate it later
  704. self.handlers.append(self)
  705. return super().initialize(
  706. close_future=close_future, compression_options=compression_options
  707. )
  708. app = Application([("/", PingHandler)])
  709. return app
  710. @staticmethod
  711. def install_hook(ws):
  712. """Optionally suppress the client's "pong" response."""
  713. ws.drop_pongs = False
  714. ws.pongs_received = 0
  715. def wrapper(fcn):
  716. def _inner(opcode: int, data: bytes):
  717. if opcode == 0xA: # NOTE: 0x9=ping, 0xA=pong
  718. ws.pongs_received += 1
  719. if ws.drop_pongs:
  720. # prevent pong responses
  721. return
  722. # leave all other responses unchanged
  723. return fcn(opcode, data)
  724. return _inner
  725. ws.protocol._handle_message = wrapper(ws.protocol._handle_message)
  726. @gen_test
  727. def test_client_ping_timeout(self):
  728. # websocket client
  729. interval = 0.2
  730. ws = yield self.ws_connect(
  731. "/", ping_interval=interval, ping_timeout=interval / 4
  732. )
  733. self.install_hook(ws)
  734. # websocket handler (server side)
  735. handler = self.handlers[0]
  736. for _ in range(5):
  737. # wait for the ping period
  738. yield gen.sleep(interval)
  739. # connection should still be open from the server end
  740. self.assertIsNone(handler.close_code)
  741. self.assertIsNone(handler.close_reason)
  742. # connection should still be open from the client end
  743. assert ws.protocol.close_code is None
  744. # Check that our hook is intercepting messages; allow for
  745. # some variance in timing (due to e.g. cpu load)
  746. self.assertGreaterEqual(ws.pongs_received, 4)
  747. # suppress the pong response message
  748. ws.drop_pongs = True
  749. # give the server time to register this
  750. yield gen.sleep(interval * 1.5)
  751. # connection should be closed from the server side
  752. self.assertEqual(handler.close_code, 1000)
  753. self.assertEqual(handler.close_reason, "ping timed out")
  754. # client should have received a close operation
  755. self.assertEqual(ws.protocol.close_code, 1000)
  756. class PingCalculationTest(unittest.TestCase):
  757. def test_ping_sleep_time(self):
  758. from tornado.websocket import WebSocketProtocol13
  759. now = datetime.datetime(2025, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
  760. interval = 10 # seconds
  761. last_ping_time = datetime.datetime(
  762. 2025, 1, 1, 11, 59, 54, tzinfo=datetime.timezone.utc
  763. )
  764. sleep_time = WebSocketProtocol13.ping_sleep_time(
  765. last_ping_time=last_ping_time.timestamp(),
  766. interval=interval,
  767. now=now.timestamp(),
  768. )
  769. self.assertEqual(sleep_time, 4)
  770. class ManualPingTest(WebSocketBaseTestCase):
  771. def get_app(self):
  772. class PingHandler(TestWebSocketHandler):
  773. def on_ping(self, data):
  774. self.write_message(data, binary=isinstance(data, bytes))
  775. return Application([("/", PingHandler)])
  776. @gen_test
  777. def test_manual_ping(self):
  778. ws = yield self.ws_connect("/")
  779. self.assertRaises(ValueError, ws.ping, "a" * 126)
  780. ws.ping("hello")
  781. resp = yield ws.read_message()
  782. # on_ping always sees bytes.
  783. self.assertEqual(resp, b"hello")
  784. ws.ping(b"binary hello")
  785. resp = yield ws.read_message()
  786. self.assertEqual(resp, b"binary hello")
  787. class MaxMessageSizeTest(WebSocketBaseTestCase):
  788. def get_app(self):
  789. return Application([("/", EchoHandler)], websocket_max_message_size=1024)
  790. @gen_test
  791. def test_large_message(self):
  792. ws = yield self.ws_connect("/")
  793. # Write a message that is allowed.
  794. msg = "a" * 1024
  795. ws.write_message(msg)
  796. resp = yield ws.read_message()
  797. self.assertEqual(resp, msg)
  798. # Write a message that is too large.
  799. ws.write_message(msg + "b")
  800. resp = yield ws.read_message()
  801. # A message of None means the other side closed the connection.
  802. self.assertIs(resp, None)
  803. self.assertEqual(ws.close_code, 1009)
  804. self.assertEqual(ws.close_reason, "message too big")
  805. # TODO: Needs tests of messages split over multiple
  806. # continuation frames.