_transport.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import io
  16. import json
  17. import os
  18. import subprocess
  19. import sys
  20. from abc import ABC, abstractmethod
  21. from typing import Callable, Dict, Optional, Union
  22. from playwright._impl._driver import compute_driver_executable, get_driver_env
  23. from playwright._impl._helper import ParsedMessagePayload
  # Sourced from: https://github.com/pytest-dev/pytest/blob/da01ee0a4bb0af780167ecd228ab3ad249511302/src/_pytest/faulthandler.py#L69-L77
  def _get_stderr_fileno() -> Optional[int]:
      """Return a file descriptor for stderr that a subprocess can inherit.

      ``sys.stderr`` is preferred. If it is not backed by a real file
      descriptor, the original ``sys.__stderr__`` is tried instead.

      Returns:
          The descriptor number, or ``None`` when no usable descriptor can be
          obtained. A ``None`` result is also what ``subprocess`` accepts to
          mean "inherit the parent's stderr".
      """
      try:
          # when using pythonw, sys.stderr is None.
          # when Pyinstaller is used, there is no closed attribute because
          # Pyinstaller monkey-patches it with a NullWriter class
          if sys.stderr is None or not hasattr(sys.stderr, "closed"):
              return None
          if sys.stderr.closed:
              return None
          return sys.stderr.fileno()
      except (NotImplementedError, AttributeError, io.UnsupportedOperation):
          # pytest-xdist monkeypatches sys.stderr with an object that is not an actual file.
          # https://docs.python.org/3/library/faulthandler.html#issue-with-file-descriptors
          # This is potentially dangerous, but the best we can do.
          fallback = getattr(sys, "__stderr__", None)
          if not fallback:
              return None
          try:
              return fallback.fileno()
          except (NotImplementedError, AttributeError, ValueError):
              # The original stream may itself be closed (ValueError) or not
              # file-backed (io.UnsupportedOperation, a ValueError subclass).
              # Treat both as "no descriptor" rather than crashing the caller.
              return None
  class Transport(ABC):
      """Abstract message channel between the client and the Playwright driver.

      Concrete subclasses own the connection lifecycle (``connect``, ``run``,
      ``request_stop``, ``wait_until_stopped``) and the low-level ``send``.
      This base class supplies:

      * JSON encode/decode helpers,
      * a replaceable ``on_message`` callback (a no-op by default),
      * an ``on_error_future`` created on the supplied loop, which subclasses
        may fail to report a fatal transport error.
      """

      def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
          self._loop = loop
          # Placeholder handler; the owner is expected to install a real one.
          self.on_message: Callable[[ParsedMessagePayload], None] = lambda _: None
          # Created on ``loop``; this base class never resolves it itself.
          self.on_error_future: asyncio.Future = loop.create_future()

      @abstractmethod
      def request_stop(self) -> None:
          """Subclass hook: begin shutting the transport down."""

      def dispose(self) -> None:
          """Release transport resources; the base implementation does nothing."""

      @abstractmethod
      async def wait_until_stopped(self) -> None:
          """Subclass hook: wait until the transport has fully stopped."""

      @abstractmethod
      async def connect(self) -> None:
          """Subclass hook: establish the underlying connection."""

      @abstractmethod
      async def run(self) -> None:
          """Subclass hook: deliver incoming messages to ``on_message`` until stopped."""

      @abstractmethod
      def send(self, message: Dict) -> None:
          """Subclass hook: transmit ``message`` to the driver."""

      def serialize_message(self, message: Dict) -> bytes:
          """Encode ``message`` as JSON and return it as bytes.

          When the ``DEBUGP`` environment variable is present, a pretty-printed
          copy is echoed to stdout with a green ``SEND>`` prefix.
          """
          encoded = json.dumps(message).encode()
          if "DEBUGP" in os.environ:  # pragma: no cover
              pretty = json.dumps(message, indent=2)
              print("\x1b[32mSEND>\x1b[0m", pretty)
          return encoded

      def deserialize_message(self, data: Union[str, bytes]) -> ParsedMessagePayload:
          """Decode one JSON payload received from the driver.

          When the ``DEBUGP`` environment variable is present, the decoded object
          is echoed to stdout with a yellow ``RECV>`` prefix.
          """
          parsed = json.loads(data)
          debugging = "DEBUGP" in os.environ
          if debugging:  # pragma: no cover
              pretty = json.dumps(parsed, indent=2)
              print("\x1b[33mRECV>\x1b[0m", pretty)
          return parsed
  class PipeTransport(Transport):
      """Transport that talks to a ``run-driver`` child process over its pipes.

      Messages travel over the child's stdin (outgoing) and stdout (incoming).
      Each frame is a 4-byte little-endian unsigned payload length followed by
      that many bytes of serialized JSON (see ``send`` and ``run``).
      """

      def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
          super().__init__(loop)
          # Set by request_stop(); polled by the read loop in run().
          self._stopped = False

      def request_stop(self) -> None:
          """Mark the transport as stopped and close the driver's stdin.

          Only valid after ``connect`` has succeeded, because ``self._output``
          is assigned there.
          """
          assert self._output
          self._stopped = True
          self._output.close()

      async def wait_until_stopped(self) -> None:
          """Wait until ``run`` has reaped the driver process.

          Only valid after ``connect``, which creates ``_stopped_future``.
          """
          await self._stopped_future

      async def connect(self) -> None:
          """Spawn the driver subprocess with piped stdin/stdout.

          Raises:
              Exception: any error raised while preparing or spawning the
                  process. The same exception is first set on
                  ``on_error_future`` and then re-raised.
          """
          # Create the future on the transport's own loop, consistent with
          # ``on_error_future`` in the base class.
          self._stopped_future: asyncio.Future = self._loop.create_future()
          try:
              env = get_driver_env()
              # For pyinstaller and Nuitka
              if getattr(sys, "frozen", False) or globals().get("__compiled__"):
                  env.setdefault("PLAYWRIGHT_BROWSERS_PATH", "0")

              startupinfo = None
              if sys.platform == "win32":
                  # Start the driver with its window hidden (SW_HIDE).
                  startupinfo = subprocess.STARTUPINFO()
                  startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
                  startupinfo.wShowWindow = subprocess.SW_HIDE

              executable_path, entrypoint_path = compute_driver_executable()
              self._proc = await asyncio.create_subprocess_exec(
                  executable_path,
                  entrypoint_path,
                  "run-driver",
                  stdin=asyncio.subprocess.PIPE,
                  stdout=asyncio.subprocess.PIPE,
                  stderr=_get_stderr_fileno(),
                  limit=32768,
                  env=env,
                  startupinfo=startupinfo,
              )
          except Exception as exc:
              self.on_error_future.set_exception(exc)
              raise
          self._output = self._proc.stdin

      async def run(self) -> None:
          """Read frames from the driver's stdout and dispatch them until stopped.

          Each payload is decoded with ``deserialize_message`` and passed to
          ``on_message``. If stdout reaches EOF before a complete frame is read
          and ``request_stop`` was not called, ``on_error_future`` is failed.
          On exit the process is awaited via ``communicate()`` and
          ``_stopped_future`` is resolved.
          """
          assert self._proc.stdout
          assert self._proc.stdin
          while not self._stopped:
              try:
                  header = await self._proc.stdout.readexactly(4)
                  if self._stopped:
                      break
                  remaining = int.from_bytes(header, byteorder="little", signed=False)
                  # Read the payload in bounded chunks and join once at the end.
                  # Concatenating bytes per chunk would re-copy the whole
                  # accumulated buffer each time, which is quadratic in the
                  # payload size.
                  chunks = []
                  while remaining:
                      to_read = min(remaining, 32768)
                      chunk = await self._proc.stdout.readexactly(to_read)
                      if self._stopped:
                          break
                      remaining -= to_read
                      chunks.append(chunk)
                  if self._stopped:
                      break
                  obj = self.deserialize_message(b"".join(chunks))
                  self.on_message(obj)
              except asyncio.IncompleteReadError:
                  if not self._stopped:
                      self.on_error_future.set_exception(
                          Exception("Connection closed while reading from the driver")
                      )
                  break
              # Give other tasks a chance to run between messages.
              await asyncio.sleep(0)

          await self._proc.communicate()
          self._stopped_future.set_result(None)

      def send(self, message: Dict) -> None:
          """Write ``message`` to the driver's stdin as one length-prefixed frame.

          The write is buffered by the stream writer; it is not drained here.
          """
          assert self._output
          data = self.serialize_message(message)
          self._output.write(
              len(data).to_bytes(4, byteorder="little", signed=False) + data
          )