test_stdio.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import asyncio
  2. import subprocess
  3. import sys
  4. import pytest
  5. from tornado.queues import Queue
  6. from jupyter_lsp.stdio import LspStdIoReader
  7. WRITER_TEMPLATE = """
  8. from time import sleep
  9. print('Content-Length: {length}')
  10. print()
  11. for repeat in range({repeats}):
  12. sleep({interval})
  13. print('{message}', end='')
  14. if {add_excess}:
  15. print("extra", end='')
  16. print()
  17. """
  18. class CommunicatorSpawner:
  19. def __init__(self, tmp_path):
  20. self.tmp_path = tmp_path
  21. def spawn_writer(
  22. self, message: str, repeats: int = 1, interval=None, add_excess=False
  23. ):
  24. length = len(message) * repeats
  25. commands_file = self.tmp_path / "writer.py"
  26. commands_file.write_text(
  27. WRITER_TEMPLATE.format(
  28. length=length,
  29. repeats=repeats,
  30. interval=interval or 0,
  31. message=message,
  32. add_excess=add_excess,
  33. )
  34. )
  35. return subprocess.Popen(
  36. [sys.executable, "-u", str(commands_file)],
  37. stdout=subprocess.PIPE,
  38. bufsize=0,
  39. )
  40. @pytest.fixture
  41. def communicator_spawner(tmp_path):
  42. return CommunicatorSpawner(tmp_path)
  43. async def join_process(process: subprocess.Popen, headstart=1, timeout=1):
  44. await asyncio.sleep(headstart)
  45. result = process.wait(timeout=timeout)
  46. if process.stdout:
  47. process.stdout.close()
  48. return result
  49. @pytest.mark.parametrize(
  50. "message,repeats,interval,add_excess",
  51. [
  52. ["short", 1, None, False],
  53. ["ab" * 10_0000, 1, None, False],
  54. ["ab", 2, 0.01, False],
  55. ["ab", 45, 0.01, False],
  56. ["message", 2, 0.01, True],
  57. ],
  58. ids=["short", "long", "intermittent", "intensive-intermittent", "with-excess"],
  59. )
  60. @pytest.mark.asyncio
  61. async def test_reader(message, repeats, interval, add_excess, communicator_spawner):
  62. queue = Queue()
  63. process = communicator_spawner.spawn_writer(
  64. message=message, repeats=repeats, interval=interval, add_excess=add_excess
  65. )
  66. reader = LspStdIoReader(stream=process.stdout, queue=queue)
  67. await asyncio.gather(join_process(process, headstart=3, timeout=1), reader.read())
  68. result = queue.get_nowait()
  69. assert result == message * repeats