netutil_test.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import errno
  2. import signal
  3. import socket
  4. from subprocess import Popen
  5. import sys
  6. import time
  7. import unittest
  8. from tornado.netutil import (
  9. BlockingResolver,
  10. OverrideResolver,
  11. ThreadedResolver,
  12. is_valid_ip,
  13. bind_sockets,
  14. )
  15. from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
  16. from tornado.test.util import skipIfNoNetwork, abstract_base_test
  17. import typing
  18. try:
  19. import pycares # type: ignore
  20. except ImportError:
  21. pycares = None
  22. else:
  23. from tornado.platform.caresresolver import CaresResolver
  24. @abstract_base_test
  25. class _ResolverTestMixin(AsyncTestCase):
  26. resolver = None # type: typing.Any
  27. @gen_test
  28. def test_localhost(self):
  29. addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC)
  30. # Most of the time localhost resolves to either the ipv4 loopback
  31. # address alone, or ipv4+ipv6. But some versions of pycares will only
  32. # return the ipv6 version, so we have to check for either one alone.
  33. self.assertTrue(
  34. ((socket.AF_INET, ("127.0.0.1", 80)) in addrinfo)
  35. or ((socket.AF_INET6, ("::1", 80)) in addrinfo),
  36. f"loopback address not found in {addrinfo}",
  37. )
  38. # It is impossible to quickly and consistently generate an error in name
  39. # resolution, so test this case separately, using mocks as needed.
  40. @abstract_base_test
  41. class _ResolverErrorTestMixin(AsyncTestCase):
  42. resolver = None # type: typing.Any
  43. @gen_test
  44. def test_bad_host(self):
  45. with self.assertRaises(IOError):
  46. yield self.resolver.resolve("an invalid domain", 80, socket.AF_UNSPEC)
  47. def _failing_getaddrinfo(*args):
  48. """Dummy implementation of getaddrinfo for use in mocks"""
  49. raise socket.gaierror(errno.EIO, "mock: lookup failed")
  50. @skipIfNoNetwork
  51. class BlockingResolverTest(_ResolverTestMixin):
  52. def setUp(self):
  53. super().setUp()
  54. self.resolver = BlockingResolver()
  55. # getaddrinfo-based tests need mocking to reliably generate errors;
  56. # some configurations are slow to produce errors and take longer than
  57. # our default timeout.
  58. class BlockingResolverErrorTest(_ResolverErrorTestMixin):
  59. def setUp(self):
  60. super().setUp()
  61. self.resolver = BlockingResolver()
  62. self.real_getaddrinfo = socket.getaddrinfo
  63. socket.getaddrinfo = _failing_getaddrinfo
  64. def tearDown(self):
  65. socket.getaddrinfo = self.real_getaddrinfo
  66. super().tearDown()
  67. class OverrideResolverTest(_ResolverTestMixin):
  68. def setUp(self):
  69. super().setUp()
  70. mapping = {
  71. ("google.com", 80): ("1.2.3.4", 80),
  72. ("google.com", 80, socket.AF_INET): ("1.2.3.4", 80),
  73. ("google.com", 80, socket.AF_INET6): (
  74. "2a02:6b8:7c:40c:c51e:495f:e23a:3",
  75. 80,
  76. ),
  77. }
  78. self.resolver = OverrideResolver(BlockingResolver(), mapping)
  79. @gen_test
  80. def test_resolve_multiaddr(self):
  81. result = yield self.resolver.resolve("google.com", 80, socket.AF_INET)
  82. self.assertIn((socket.AF_INET, ("1.2.3.4", 80)), result)
  83. result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6)
  84. self.assertIn(
  85. (socket.AF_INET6, ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0)), result
  86. )
  87. @skipIfNoNetwork
  88. class ThreadedResolverTest(_ResolverTestMixin):
  89. def setUp(self):
  90. super().setUp()
  91. self.resolver = ThreadedResolver()
  92. def tearDown(self):
  93. self.resolver.close()
  94. super().tearDown()
  95. class ThreadedResolverErrorTest(_ResolverErrorTestMixin):
  96. def setUp(self):
  97. super().setUp()
  98. self.resolver = BlockingResolver()
  99. self.real_getaddrinfo = socket.getaddrinfo
  100. socket.getaddrinfo = _failing_getaddrinfo
  101. def tearDown(self):
  102. socket.getaddrinfo = self.real_getaddrinfo
  103. super().tearDown()
  104. @skipIfNoNetwork
  105. @unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32")
  106. class ThreadedResolverImportTest(unittest.TestCase):
  107. def test_import(self):
  108. TIMEOUT = 5
  109. # Test for a deadlock when importing a module that runs the
  110. # ThreadedResolver at import-time. See resolve_test.py for
  111. # full explanation.
  112. command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"]
  113. start = time.time()
  114. popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
  115. while time.time() - start < TIMEOUT:
  116. return_code = popen.poll()
  117. if return_code is not None:
  118. self.assertEqual(0, return_code)
  119. return # Success.
  120. time.sleep(0.05)
  121. self.fail("import timed out")
  122. # We do not test errors with CaresResolver:
  123. # Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
  124. # with an NXDOMAIN status code. Most resolvers treat this as an error;
  125. # C-ares returns the results, making the "bad_host" tests unreliable.
  126. # C-ares will try to resolve even malformed names, such as the
  127. # name with spaces used in this test.
  128. @skipIfNoNetwork
  129. @unittest.skipIf(pycares is None, "pycares module not present")
  130. @unittest.skipIf(sys.platform == "win32", "pycares doesn't return loopback on windows")
  131. @unittest.skipIf(sys.platform == "darwin", "pycares doesn't return 127.0.0.1 on darwin")
  132. class CaresResolverTest(_ResolverTestMixin):
  133. def setUp(self):
  134. super().setUp()
  135. self.resolver = CaresResolver()
  136. class IsValidIPTest(unittest.TestCase):
  137. def test_is_valid_ip(self):
  138. self.assertTrue(is_valid_ip("127.0.0.1"))
  139. self.assertTrue(is_valid_ip("4.4.4.4"))
  140. self.assertTrue(is_valid_ip("::1"))
  141. self.assertTrue(is_valid_ip("2620:0:1cfe:face:b00c::3"))
  142. self.assertFalse(is_valid_ip("www.google.com"))
  143. self.assertFalse(is_valid_ip("localhost"))
  144. self.assertFalse(is_valid_ip("4.4.4.4<"))
  145. self.assertFalse(is_valid_ip(" 127.0.0.1"))
  146. self.assertFalse(is_valid_ip(""))
  147. self.assertFalse(is_valid_ip(" "))
  148. self.assertFalse(is_valid_ip("\n"))
  149. self.assertFalse(is_valid_ip("\x00"))
  150. self.assertFalse(is_valid_ip("a" * 100))
  151. class TestPortAllocation(unittest.TestCase):
  152. def test_same_port_allocation(self):
  153. sockets = bind_sockets(0, "localhost")
  154. try:
  155. port = sockets[0].getsockname()[1]
  156. self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:]))
  157. finally:
  158. for sock in sockets:
  159. sock.close()
  160. @unittest.skipIf(
  161. not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported"
  162. )
  163. def test_reuse_port(self):
  164. sockets: typing.List[socket.socket] = []
  165. sock, port = bind_unused_port(reuse_port=True)
  166. try:
  167. sockets = bind_sockets(port, "127.0.0.1", reuse_port=True)
  168. self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
  169. finally:
  170. sock.close()
  171. for sock in sockets:
  172. sock.close()