test_handshake_large_response.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # -*- coding: utf-8 -*-
  2. import unittest
  3. from unittest.mock import Mock, patch
  4. from websocket._handshake import _get_resp_headers
  5. from websocket._exceptions import WebSocketBadStatusException
  6. from websocket._ssl_compat import SSLError
  7. """
  8. test_handshake_large_response.py
  9. websocket - WebSocket client library for Python
  10. Copyright 2025 engn33r
  11. Licensed under the Apache License, Version 2.0 (the "License");
  12. you may not use this file except in compliance with the License.
  13. You may obtain a copy of the License at
  14. http://www.apache.org/licenses/LICENSE-2.0
  15. Unless required by applicable law or agreed to in writing, software
  16. distributed under the License is distributed on an "AS IS" BASIS,
  17. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. See the License for the specific language governing permissions and
  19. limitations under the License.
  20. """
  21. class HandshakeLargeResponseTest(unittest.TestCase):
  22. def test_large_error_response_chunked_reading(self):
  23. """Test that large HTTP error responses during handshake are read in chunks"""
  24. # Mock socket
  25. mock_sock = Mock()
  26. # Create a large error response body (> 16KB)
  27. large_response = b"Error details: " + b"A" * 20000 # 20KB+ response
  28. # Track recv calls to ensure chunking
  29. recv_calls = []
  30. def mock_recv(sock, bufsize):
  31. recv_calls.append(bufsize)
  32. # Simulate SSL error if trying to read > 16KB at once
  33. if bufsize > 16384:
  34. raise SSLError("[SSL: BAD_LENGTH] unknown error")
  35. return large_response[:bufsize]
  36. # Mock read_headers to return error status with large content-length
  37. with patch("websocket._handshake.read_headers") as mock_read_headers:
  38. mock_read_headers.return_value = (
  39. 400, # Bad request status
  40. {"content-length": str(len(large_response))},
  41. "Bad Request",
  42. )
  43. # Mock the recv function to track calls
  44. with patch("websocket._socket.recv", side_effect=mock_recv):
  45. # This should not raise SSLError, but should raise WebSocketBadStatusException
  46. with self.assertRaises(WebSocketBadStatusException) as cm:
  47. _get_resp_headers(mock_sock)
  48. # Verify the response body was included in the exception
  49. self.assertIn(
  50. b"Error details:",
  51. (
  52. cm.exception.args[0].encode()
  53. if isinstance(cm.exception.args[0], str)
  54. else cm.exception.args[0]
  55. ),
  56. )
  57. # Verify chunked reading was used (multiple recv calls, none > 16KB)
  58. self.assertGreater(len(recv_calls), 1)
  59. self.assertTrue(all(call <= 16384 for call in recv_calls))
  60. def test_handshake_ssl_large_response_protection(self):
  61. """Test that the fix prevents SSL BAD_LENGTH errors during handshake"""
  62. mock_sock = Mock()
  63. # Large content that would trigger SSL error if read all at once
  64. large_content = b"X" * 32768 # 32KB
  65. chunks_returned = 0
  66. def mock_recv_chunked(sock, bufsize):
  67. nonlocal chunks_returned
  68. # Return data in chunks, simulating successful chunked reading
  69. chunk_start = chunks_returned * 16384
  70. chunk_end = min(chunk_start + bufsize, len(large_content))
  71. result = large_content[chunk_start:chunk_end]
  72. chunks_returned += 1 if result else 0
  73. return result
  74. with patch("websocket._handshake.read_headers") as mock_read_headers:
  75. mock_read_headers.return_value = (
  76. 500, # Server error
  77. {"content-length": str(len(large_content))},
  78. "Internal Server Error",
  79. )
  80. with patch("websocket._socket.recv", side_effect=mock_recv_chunked):
  81. # Should handle large response without SSL errors
  82. with self.assertRaises(WebSocketBadStatusException) as cm:
  83. _get_resp_headers(mock_sock)
  84. # Verify the complete response was captured
  85. exception_str = str(cm.exception)
  86. # Response body should be in the exception message
  87. self.assertIn("XXXXX", exception_str) # Part of the large content
  88. def test_handshake_normal_small_response(self):
  89. """Test that normal small responses still work correctly"""
  90. mock_sock = Mock()
  91. small_response = b"Small error message"
  92. def mock_recv(sock, bufsize):
  93. return small_response
  94. with patch("websocket._handshake.read_headers") as mock_read_headers:
  95. mock_read_headers.return_value = (
  96. 404, # Not found
  97. {"content-length": str(len(small_response))},
  98. "Not Found",
  99. )
  100. with patch("websocket._socket.recv", side_effect=mock_recv):
  101. with self.assertRaises(WebSocketBadStatusException) as cm:
  102. _get_resp_headers(mock_sock)
  103. # Verify small response is handled correctly
  104. self.assertIn("Small error message", str(cm.exception))
  105. def test_handshake_no_content_length(self):
  106. """Test handshake error response without content-length header"""
  107. mock_sock = Mock()
  108. with patch("websocket._handshake.read_headers") as mock_read_headers:
  109. mock_read_headers.return_value = (
  110. 403, # Forbidden
  111. {}, # No content-length header
  112. "Forbidden",
  113. )
  114. # Should raise exception without trying to read response body
  115. with self.assertRaises(WebSocketBadStatusException) as cm:
  116. _get_resp_headers(mock_sock)
  117. # Should mention status but not have response body
  118. exception_str = str(cm.exception)
  119. self.assertIn("403", exception_str)
  120. self.assertIn("Forbidden", exception_str)
  121. if __name__ == "__main__":
  122. unittest.main()