concurrent_test.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. #
  2. # Copyright 2012 Facebook
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  5. # not use this file except in compliance with the License. You may obtain
  6. # a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  12. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  13. # License for the specific language governing permissions and limitations
  14. # under the License.
  15. from concurrent import futures
  16. import logging
  17. import re
  18. import socket
  19. import unittest
  20. from tornado.concurrent import (
  21. Future,
  22. chain_future,
  23. run_on_executor,
  24. future_set_result_unless_cancelled,
  25. )
  26. from tornado.escape import utf8, to_unicode
  27. from tornado import gen
  28. from tornado.iostream import IOStream
  29. from tornado.tcpserver import TCPServer
  30. from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
  31. class MiscFutureTest(AsyncTestCase):
  32. def test_future_set_result_unless_cancelled(self):
  33. fut = Future() # type: Future[int]
  34. future_set_result_unless_cancelled(fut, 42)
  35. self.assertEqual(fut.result(), 42)
  36. self.assertFalse(fut.cancelled())
  37. fut = Future()
  38. fut.cancel()
  39. is_cancelled = fut.cancelled()
  40. future_set_result_unless_cancelled(fut, 42)
  41. self.assertEqual(fut.cancelled(), is_cancelled)
  42. if not is_cancelled:
  43. self.assertEqual(fut.result(), 42)
  44. class ChainFutureTest(AsyncTestCase):
  45. @gen_test
  46. async def test_asyncio_futures(self):
  47. fut: Future[int] = Future()
  48. fut2: Future[int] = Future()
  49. chain_future(fut, fut2)
  50. fut.set_result(42)
  51. result = await fut2
  52. self.assertEqual(result, 42)
  53. @gen_test
  54. async def test_concurrent_futures(self):
  55. # A three-step chain: two concurrent futures (showing that both arguments to chain_future
  56. # can be concurrent futures), and then one from a concurrent future to an asyncio future so
  57. # we can use it in await.
  58. fut: futures.Future[int] = futures.Future()
  59. fut2: futures.Future[int] = futures.Future()
  60. fut3: Future[int] = Future()
  61. chain_future(fut, fut2)
  62. chain_future(fut2, fut3)
  63. fut.set_result(42)
  64. result = await fut3
  65. self.assertEqual(result, 42)
# The following classes demonstrate and test various styles of use,
# with and without generators and futures.
  68. class CapServer(TCPServer):
  69. @gen.coroutine
  70. def handle_stream(self, stream, address):
  71. data = yield stream.read_until(b"\n")
  72. data = to_unicode(data)
  73. if data == data.upper():
  74. stream.write(b"error\talready capitalized\n")
  75. else:
  76. # data already has \n
  77. stream.write(utf8("ok\t%s" % data.upper()))
  78. stream.close()
  79. class CapError(Exception):
  80. pass
  81. class BaseCapClient:
  82. def __init__(self, port):
  83. self.port = port
  84. def process_response(self, data):
  85. m = re.match("(.*)\t(.*)\n", to_unicode(data))
  86. if m is None:
  87. raise Exception("did not match")
  88. status, message = m.groups()
  89. if status == "ok":
  90. return message
  91. else:
  92. raise CapError(message)
  93. class GeneratorCapClient(BaseCapClient):
  94. @gen.coroutine
  95. def capitalize(self, request_data):
  96. logging.debug("capitalize")
  97. stream = IOStream(socket.socket())
  98. logging.debug("connecting")
  99. yield stream.connect(("127.0.0.1", self.port))
  100. stream.write(utf8(request_data + "\n"))
  101. logging.debug("reading")
  102. data = yield stream.read_until(b"\n")
  103. logging.debug("returning")
  104. stream.close()
  105. raise gen.Return(self.process_response(data))
  106. class GeneratorCapClientTest(AsyncTestCase):
  107. def setUp(self):
  108. super().setUp()
  109. self.server = CapServer()
  110. sock, port = bind_unused_port()
  111. self.server.add_sockets([sock])
  112. self.client = GeneratorCapClient(port=port)
  113. def tearDown(self):
  114. self.server.stop()
  115. super().tearDown()
  116. def test_future(self):
  117. future = self.client.capitalize("hello")
  118. self.io_loop.add_future(future, self.stop)
  119. self.wait()
  120. self.assertEqual(future.result(), "HELLO")
  121. def test_future_error(self):
  122. future = self.client.capitalize("HELLO")
  123. self.io_loop.add_future(future, self.stop)
  124. self.wait()
  125. self.assertRaisesRegex(CapError, "already capitalized", future.result)
  126. def test_generator(self):
  127. @gen.coroutine
  128. def f():
  129. result = yield self.client.capitalize("hello")
  130. self.assertEqual(result, "HELLO")
  131. self.io_loop.run_sync(f)
  132. def test_generator_error(self):
  133. @gen.coroutine
  134. def f():
  135. with self.assertRaisesRegex(CapError, "already capitalized"):
  136. yield self.client.capitalize("HELLO")
  137. self.io_loop.run_sync(f)
  138. class RunOnExecutorTest(AsyncTestCase):
  139. @gen_test
  140. def test_no_calling(self):
  141. class Object:
  142. def __init__(self):
  143. self.executor = futures.thread.ThreadPoolExecutor(1)
  144. @run_on_executor
  145. def f(self):
  146. return 42
  147. o = Object()
  148. answer = yield o.f()
  149. self.assertEqual(answer, 42)
  150. @gen_test
  151. def test_call_with_no_args(self):
  152. class Object:
  153. def __init__(self):
  154. self.executor = futures.thread.ThreadPoolExecutor(1)
  155. @run_on_executor()
  156. def f(self):
  157. return 42
  158. o = Object()
  159. answer = yield o.f()
  160. self.assertEqual(answer, 42)
  161. @gen_test
  162. def test_call_with_executor(self):
  163. class Object:
  164. def __init__(self):
  165. self.__executor = futures.thread.ThreadPoolExecutor(1)
  166. @run_on_executor(executor="_Object__executor")
  167. def f(self):
  168. return 42
  169. o = Object()
  170. answer = yield o.f()
  171. self.assertEqual(answer, 42)
  172. @gen_test
  173. def test_async_await(self):
  174. class Object:
  175. def __init__(self):
  176. self.executor = futures.thread.ThreadPoolExecutor(1)
  177. @run_on_executor()
  178. def f(self):
  179. return 42
  180. o = Object()
  181. async def f():
  182. answer = await o.f()
  183. return answer
  184. result = yield f()
  185. self.assertEqual(result, 42)
  186. if __name__ == "__main__":
  187. unittest.main()