tcpserver_test.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import socket
  2. import subprocess
  3. import sys
  4. import textwrap
  5. import unittest
  6. from tornado import gen
  7. from tornado.iostream import IOStream
  8. from tornado.log import app_log
  9. from tornado.tcpserver import TCPServer
  10. from tornado.test.util import skipIfNonUnix
  11. from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
  12. from typing import Tuple
  13. class TCPServerTest(AsyncTestCase):
  14. @gen_test
  15. def test_handle_stream_coroutine_logging(self):
  16. # handle_stream may be a coroutine and any exception in its
  17. # Future will be logged.
  18. class TestServer(TCPServer):
  19. @gen.coroutine
  20. def handle_stream(self, stream, address):
  21. yield stream.read_bytes(len(b"hello"))
  22. stream.close()
  23. 1 / 0
  24. server = client = None
  25. try:
  26. sock, port = bind_unused_port()
  27. server = TestServer()
  28. server.add_socket(sock)
  29. client = IOStream(socket.socket())
  30. with ExpectLog(app_log, "Exception in callback"):
  31. yield client.connect(("localhost", port))
  32. yield client.write(b"hello")
  33. yield client.read_until_close()
  34. yield gen.moment
  35. finally:
  36. if server is not None:
  37. server.stop()
  38. if client is not None:
  39. client.close()
  40. @gen_test
  41. def test_handle_stream_native_coroutine(self):
  42. # handle_stream may be a native coroutine.
  43. class TestServer(TCPServer):
  44. async def handle_stream(self, stream, address):
  45. stream.write(b"data")
  46. stream.close()
  47. sock, port = bind_unused_port()
  48. server = TestServer()
  49. server.add_socket(sock)
  50. client = IOStream(socket.socket())
  51. yield client.connect(("localhost", port))
  52. result = yield client.read_until_close()
  53. self.assertEqual(result, b"data")
  54. server.stop()
  55. client.close()
  56. def test_stop_twice(self):
  57. sock, port = bind_unused_port()
  58. server = TCPServer()
  59. server.add_socket(sock)
  60. server.stop()
  61. server.stop()
  62. @gen_test
  63. def test_stop_in_callback(self):
  64. # Issue #2069: calling server.stop() in a loop callback should not
  65. # raise EBADF when the loop handles other server connection
  66. # requests in the same loop iteration
  67. class TestServer(TCPServer):
  68. @gen.coroutine
  69. def handle_stream(self, stream, address):
  70. server.stop() # type: ignore
  71. yield stream.read_until_close()
  72. sock, port = bind_unused_port()
  73. server = TestServer()
  74. server.add_socket(sock)
  75. server_addr = ("localhost", port)
  76. N = 40
  77. clients = [IOStream(socket.socket()) for i in range(N)]
  78. connected_clients = []
  79. @gen.coroutine
  80. def connect(c):
  81. try:
  82. yield c.connect(server_addr)
  83. except OSError:
  84. pass
  85. else:
  86. connected_clients.append(c)
  87. yield [connect(c) for c in clients]
  88. self.assertGreater(len(connected_clients), 0, "all clients failed connecting")
  89. try:
  90. if len(connected_clients) == N:
  91. # Ideally we'd make the test deterministic, but we're testing
  92. # for a race condition in combination with the system's TCP stack...
  93. self.skipTest(
  94. "at least one client should fail connecting "
  95. "for the test to be meaningful"
  96. )
  97. finally:
  98. for c in connected_clients:
  99. c.close()
  100. # Here tearDown() would re-raise the EBADF encountered in the IO loop
  101. @skipIfNonUnix
  102. class TestMultiprocess(unittest.TestCase):
  103. # These tests verify that the two multiprocess examples from the
  104. # TCPServer docs work. Both tests start a server with three worker
  105. # processes, each of which prints its task id to stdout (a single
  106. # byte, so we don't have to worry about atomicity of the shared
  107. # stdout stream) and then exits.
  108. def run_subproc(self, code: str) -> Tuple[str, str]:
  109. try:
  110. result = subprocess.run(
  111. [sys.executable, "-Werror::DeprecationWarning"],
  112. capture_output=True,
  113. input=code,
  114. encoding="utf8",
  115. check=True,
  116. )
  117. except subprocess.CalledProcessError as e:
  118. raise RuntimeError(
  119. f"Process returned {e.returncode} stdout={e.stdout} stderr={e.stderr}"
  120. ) from e
  121. return result.stdout, result.stderr
  122. def test_listen_single(self):
  123. # As a sanity check, run the single-process version through this test
  124. # harness too.
  125. code = textwrap.dedent(
  126. """
  127. import asyncio
  128. from tornado.tcpserver import TCPServer
  129. async def main():
  130. server = TCPServer()
  131. server.listen(0, address='127.0.0.1')
  132. asyncio.run(main())
  133. print('012', end='')
  134. """
  135. )
  136. out, err = self.run_subproc(code)
  137. self.assertEqual("".join(sorted(out)), "012")
  138. self.assertEqual(err, "")
  139. def test_bind_start(self):
  140. code = textwrap.dedent(
  141. """
  142. import warnings
  143. from tornado.ioloop import IOLoop
  144. from tornado.process import task_id
  145. from tornado.tcpserver import TCPServer
  146. warnings.simplefilter("ignore", DeprecationWarning)
  147. server = TCPServer()
  148. server.bind(0, address='127.0.0.1')
  149. server.start(3)
  150. IOLoop.current().run_sync(lambda: None)
  151. print(task_id(), end='')
  152. """
  153. )
  154. out, err = self.run_subproc(code)
  155. self.assertEqual("".join(sorted(out)), "012")
  156. self.assertEqual(err, "")
  157. def test_add_sockets(self):
  158. code = textwrap.dedent(
  159. """
  160. import asyncio
  161. from tornado.netutil import bind_sockets
  162. from tornado.process import fork_processes, task_id
  163. from tornado.ioloop import IOLoop
  164. from tornado.tcpserver import TCPServer
  165. sockets = bind_sockets(0, address='127.0.0.1')
  166. fork_processes(3)
  167. async def post_fork_main():
  168. server = TCPServer()
  169. server.add_sockets(sockets)
  170. asyncio.run(post_fork_main())
  171. print(task_id(), end='')
  172. """
  173. )
  174. out, err = self.run_subproc(code)
  175. self.assertEqual("".join(sorted(out)), "012")
  176. self.assertEqual(err, "")
  177. def test_listen_multi_reuse_port(self):
  178. code = textwrap.dedent(
  179. """
  180. import asyncio
  181. import socket
  182. from tornado.netutil import bind_sockets
  183. from tornado.process import task_id, fork_processes
  184. from tornado.tcpserver import TCPServer
  185. # Pick an unused port which we will be able to bind to multiple times.
  186. (sock,) = bind_sockets(0, address='127.0.0.1',
  187. family=socket.AF_INET, reuse_port=True)
  188. port = sock.getsockname()[1]
  189. fork_processes(3)
  190. async def main():
  191. server = TCPServer()
  192. server.listen(port, address='127.0.0.1', reuse_port=True)
  193. asyncio.run(main())
  194. print(task_id(), end='')
  195. """
  196. )
  197. out, err = self.run_subproc(code)
  198. self.assertEqual("".join(sorted(out)), "012")
  199. self.assertEqual(err, "")