_connection.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import collections.abc
  16. import contextvars
  17. import datetime
  18. import inspect
  19. import sys
  20. import traceback
  21. from pathlib import Path
  22. from typing import (
  23. TYPE_CHECKING,
  24. Any,
  25. Callable,
  26. Dict,
  27. List,
  28. Mapping,
  29. Optional,
  30. TypedDict,
  31. Union,
  32. cast,
  33. )
  34. from pyee import EventEmitter
  35. from pyee.asyncio import AsyncIOEventEmitter
  36. import playwright
  37. import playwright._impl._impl_to_api_mapping
  38. from playwright._impl._errors import TargetClosedError, rewrite_error
  39. from playwright._impl._greenlets import EventGreenlet
  40. from playwright._impl._helper import Error, ParsedMessagePayload, parse_error
  41. from playwright._impl._transport import Transport
  42. if TYPE_CHECKING:
  43. from playwright._impl._local_utils import LocalUtils
  44. from playwright._impl._playwright import Playwright
# Optional hook used when building outgoing message params: when set, it is
# called with the caller-supplied "timeout" entry (or None when absent) and its
# return value is stored back as the "timeout" param. None disables injection.
TimeoutCalculator = Optional[Callable[[Optional[float]], float]]
  46. class Channel(AsyncIOEventEmitter):
  47. def __init__(self, connection: "Connection", object: "ChannelOwner") -> None:
  48. super().__init__()
  49. self._connection = connection
  50. self._guid = object._guid
  51. self._object = object
  52. self.on("error", lambda exc: self._connection._on_event_listener_error(exc))
  53. async def send(
  54. self,
  55. method: str,
  56. timeout_calculator: TimeoutCalculator,
  57. params: Dict = None,
  58. is_internal: bool = False,
  59. title: str = None,
  60. ) -> Any:
  61. return await self._connection.wrap_api_call(
  62. lambda: self._inner_send(method, timeout_calculator, params, False),
  63. is_internal,
  64. title,
  65. )
  66. async def send_return_as_dict(
  67. self,
  68. method: str,
  69. timeout_calculator: TimeoutCalculator,
  70. params: Dict = None,
  71. is_internal: bool = False,
  72. title: str = None,
  73. ) -> Any:
  74. return await self._connection.wrap_api_call(
  75. lambda: self._inner_send(method, timeout_calculator, params, True),
  76. is_internal,
  77. title,
  78. )
  79. def send_no_reply(
  80. self,
  81. method: str,
  82. timeout_calculator: TimeoutCalculator,
  83. params: Dict = None,
  84. is_internal: bool = False,
  85. title: str = None,
  86. ) -> None:
  87. # No reply messages are used to e.g. waitForEventInfo(after).
  88. self._connection.wrap_api_call_sync(
  89. lambda: self._connection._send_message_to_server(
  90. self._object,
  91. method,
  92. _augment_params(params, timeout_calculator),
  93. True,
  94. ),
  95. is_internal,
  96. title,
  97. )
  98. async def _inner_send(
  99. self,
  100. method: str,
  101. timeout_calculator: TimeoutCalculator,
  102. params: Optional[Dict],
  103. return_as_dict: bool,
  104. ) -> Any:
  105. if self._connection._error:
  106. error = self._connection._error
  107. self._connection._error = None
  108. raise error
  109. callback = self._connection._send_message_to_server(
  110. self._object, method, _augment_params(params, timeout_calculator)
  111. )
  112. done, _ = await asyncio.wait(
  113. {
  114. self._connection._transport.on_error_future,
  115. callback.future,
  116. },
  117. return_when=asyncio.FIRST_COMPLETED,
  118. )
  119. if not callback.future.done():
  120. callback.future.cancel()
  121. result = next(iter(done)).result()
  122. # Protocol now has named return values, assume result is one level deeper unless
  123. # there is explicit ambiguity.
  124. if not result:
  125. return None
  126. assert isinstance(result, dict)
  127. if return_as_dict:
  128. return result
  129. if len(result) == 0:
  130. return None
  131. assert len(result) == 1
  132. key = next(iter(result))
  133. return result[key]
  134. class ChannelOwner(AsyncIOEventEmitter):
  135. def __init__(
  136. self,
  137. parent: Union["ChannelOwner", "Connection"],
  138. type: str,
  139. guid: str,
  140. initializer: Dict,
  141. ) -> None:
  142. super().__init__(loop=parent._loop)
  143. self._loop: asyncio.AbstractEventLoop = parent._loop
  144. self._dispatcher_fiber: Any = parent._dispatcher_fiber
  145. self._type = type
  146. self._guid: str = guid
  147. self._connection: Connection = (
  148. parent._connection if isinstance(parent, ChannelOwner) else parent
  149. )
  150. self._parent: Optional[ChannelOwner] = (
  151. parent if isinstance(parent, ChannelOwner) else None
  152. )
  153. self._objects: Dict[str, "ChannelOwner"] = {}
  154. self._channel: Channel = Channel(self._connection, self)
  155. self._initializer = initializer
  156. self._was_collected = False
  157. self._connection._objects[guid] = self
  158. if self._parent:
  159. self._parent._objects[guid] = self
  160. self._event_to_subscription_mapping: Dict[str, str] = {}
  161. def _dispose(self, reason: Optional[str]) -> None:
  162. # Clean up from parent and connection.
  163. if self._parent:
  164. del self._parent._objects[self._guid]
  165. del self._connection._objects[self._guid]
  166. self._was_collected = reason == "gc"
  167. # Dispose all children.
  168. for object in list(self._objects.values()):
  169. object._dispose(reason)
  170. self._objects.clear()
  171. def _adopt(self, child: "ChannelOwner") -> None:
  172. del cast("ChannelOwner", child._parent)._objects[child._guid]
  173. self._objects[child._guid] = child
  174. child._parent = self
  175. def _set_event_to_subscription_mapping(self, mapping: Dict[str, str]) -> None:
  176. self._event_to_subscription_mapping = mapping
  177. def _update_subscription(self, event: str, enabled: bool) -> None:
  178. protocol_event = self._event_to_subscription_mapping.get(event)
  179. if protocol_event:
  180. self._connection.wrap_api_call_sync(
  181. lambda: self._channel.send_no_reply(
  182. "updateSubscription",
  183. None,
  184. {"event": protocol_event, "enabled": enabled},
  185. ),
  186. True,
  187. )
  188. def _add_event_handler(self, event: str, k: Any, v: Any) -> None:
  189. if not self.listeners(event):
  190. self._update_subscription(event, True)
  191. super()._add_event_handler(event, k, v)
  192. def remove_listener(self, event: str, f: Any) -> None:
  193. super().remove_listener(event, f)
  194. if not self.listeners(event):
  195. self._update_subscription(event, False)
  196. class ProtocolCallback:
  197. def __init__(self, loop: asyncio.AbstractEventLoop, no_reply: bool = False) -> None:
  198. self.stack_trace: traceback.StackSummary
  199. self.no_reply = no_reply
  200. self.future = loop.create_future()
  201. if no_reply:
  202. self.future.set_result(None)
  203. # The outer task can get cancelled by the user, this forwards the cancellation to the inner task.
  204. current_task = asyncio.current_task()
  205. def cb(task: asyncio.Task) -> None:
  206. if current_task:
  207. current_task.remove_done_callback(cb)
  208. if task.cancelled():
  209. self.future.cancel()
  210. if current_task:
  211. current_task.add_done_callback(cb)
  212. self.future.add_done_callback(
  213. lambda _: (
  214. current_task.remove_done_callback(cb) if current_task else None
  215. )
  216. )
  217. class RootChannelOwner(ChannelOwner):
  218. def __init__(self, connection: "Connection") -> None:
  219. super().__init__(connection, "Root", "", {})
  220. async def initialize(self) -> "Playwright":
  221. return from_channel(
  222. await self._channel.send(
  223. "initialize",
  224. None,
  225. {
  226. "sdkLanguage": "python",
  227. },
  228. )
  229. )
  230. class Connection(EventEmitter):
  231. def __init__(
  232. self,
  233. dispatcher_fiber: Any,
  234. object_factory: Callable[[ChannelOwner, str, str, Dict], ChannelOwner],
  235. transport: Transport,
  236. loop: asyncio.AbstractEventLoop,
  237. local_utils: Optional["LocalUtils"] = None,
  238. ) -> None:
  239. super().__init__()
  240. self._dispatcher_fiber = dispatcher_fiber
  241. self._transport = transport
  242. self._transport.on_message = lambda msg: self.dispatch(msg)
  243. self._waiting_for_object: Dict[str, Callable[[ChannelOwner], None]] = {}
  244. self._last_id = 0
  245. self._objects: Dict[str, ChannelOwner] = {}
  246. self._callbacks: Dict[int, ProtocolCallback] = {}
  247. self._object_factory = object_factory
  248. self._is_sync = False
  249. self._child_ws_connections: List["Connection"] = []
  250. self._loop = loop
  251. self.playwright_future: asyncio.Future["Playwright"] = loop.create_future()
  252. self._error: Optional[BaseException] = None
  253. self.is_remote = False
  254. self._init_task: Optional[asyncio.Task] = None
  255. self._api_zone: contextvars.ContextVar[Optional[ParsedStackTrace]] = (
  256. contextvars.ContextVar("ApiZone", default=None)
  257. )
  258. self._local_utils: Optional["LocalUtils"] = local_utils
  259. self._tracing_count = 0
  260. self._closed_error: Optional[Exception] = None
  261. @property
  262. def local_utils(self) -> "LocalUtils":
  263. assert self._local_utils
  264. return self._local_utils
  265. def mark_as_remote(self) -> None:
  266. self.is_remote = True
  267. async def run_as_sync(self) -> None:
  268. self._is_sync = True
  269. await self.run()
  270. async def run(self) -> None:
  271. self._loop = asyncio.get_running_loop()
  272. self._root_object = RootChannelOwner(self)
  273. async def init() -> None:
  274. self.playwright_future.set_result(await self._root_object.initialize())
  275. await self._transport.connect()
  276. self._init_task = self._loop.create_task(init())
  277. await self._transport.run()
  278. def stop_sync(self) -> None:
  279. self._transport.request_stop()
  280. self._dispatcher_fiber.switch()
  281. self._loop.run_until_complete(self._transport.wait_until_stopped())
  282. self.cleanup()
  283. async def stop_async(self) -> None:
  284. self._transport.request_stop()
  285. await self._transport.wait_until_stopped()
  286. self.cleanup()
  287. def cleanup(self, cause: str = None) -> None:
  288. self._closed_error = TargetClosedError(cause) if cause else TargetClosedError()
  289. if self._init_task and not self._init_task.done():
  290. self._init_task.cancel()
  291. for ws_connection in self._child_ws_connections:
  292. ws_connection._transport.dispose()
  293. for callback in self._callbacks.values():
  294. # To prevent 'Future exception was never retrieved' we ignore all callbacks that are no_reply.
  295. if callback.no_reply:
  296. continue
  297. if callback.future.cancelled():
  298. continue
  299. callback.future.set_exception(self._closed_error)
  300. self._callbacks.clear()
  301. self.emit("close")
  302. def call_on_object_with_known_name(
  303. self, guid: str, callback: Callable[[ChannelOwner], None]
  304. ) -> None:
  305. self._waiting_for_object[guid] = callback
  306. def set_is_tracing(self, is_tracing: bool) -> None:
  307. if is_tracing:
  308. self._tracing_count += 1
  309. else:
  310. self._tracing_count -= 1
  311. def _send_message_to_server(
  312. self, object: ChannelOwner, method: str, params: Dict, no_reply: bool = False
  313. ) -> ProtocolCallback:
  314. if self._closed_error:
  315. raise self._closed_error
  316. if object._was_collected:
  317. raise Error(
  318. "The object has been collected to prevent unbounded heap growth."
  319. )
  320. self._last_id += 1
  321. id = self._last_id
  322. callback = ProtocolCallback(self._loop, no_reply=no_reply)
  323. task = asyncio.current_task(self._loop)
  324. callback.stack_trace = cast(
  325. traceback.StackSummary,
  326. getattr(task, "__pw_stack_trace__", traceback.extract_stack(limit=10)),
  327. )
  328. callback.no_reply = no_reply
  329. stack_trace_information = cast(ParsedStackTrace, self._api_zone.get())
  330. frames = stack_trace_information.get("frames", [])
  331. location = (
  332. {
  333. "file": frames[0]["file"],
  334. "line": frames[0]["line"],
  335. "column": frames[0]["column"],
  336. }
  337. if frames
  338. else None
  339. )
  340. metadata = {
  341. "wallTime": int(datetime.datetime.now().timestamp() * 1000),
  342. "apiName": stack_trace_information["apiName"],
  343. "internal": not stack_trace_information["apiName"],
  344. }
  345. if location:
  346. metadata["location"] = location # type: ignore
  347. title = stack_trace_information["title"]
  348. if title:
  349. metadata["title"] = title
  350. message = {
  351. "id": id,
  352. "guid": object._guid,
  353. "method": method,
  354. "params": self._replace_channels_with_guids(params),
  355. "metadata": metadata,
  356. }
  357. if self._tracing_count > 0 and frames and object._guid != "localUtils":
  358. self.local_utils.add_stack_to_tracing_no_reply(id, frames)
  359. self._callbacks[id] = callback
  360. self._transport.send(message)
  361. return callback
  362. def dispatch(self, msg: ParsedMessagePayload) -> None:
  363. if self._closed_error:
  364. return
  365. id = msg.get("id")
  366. if id:
  367. callback = self._callbacks.pop(id)
  368. if callback.future.cancelled():
  369. return
  370. # No reply messages are used to e.g. waitForEventInfo(after) which returns exceptions on page close.
  371. # To prevent 'Future exception was never retrieved' we just ignore such messages.
  372. if callback.no_reply:
  373. return
  374. error = msg.get("error")
  375. if error and not msg.get("result"):
  376. parsed_error = parse_error(
  377. error["error"], format_call_log(msg.get("log")) # type: ignore
  378. )
  379. parsed_error._stack = "".join(callback.stack_trace.format())
  380. callback.future.set_exception(parsed_error)
  381. else:
  382. result = self._replace_guids_with_channels(msg.get("result"))
  383. callback.future.set_result(result)
  384. return
  385. guid = msg["guid"]
  386. method = msg["method"]
  387. params = msg.get("params")
  388. if method == "__create__":
  389. assert params
  390. parent = self._objects[guid]
  391. self._create_remote_object(
  392. parent, params["type"], params["guid"], params["initializer"]
  393. )
  394. return
  395. object = self._objects.get(guid)
  396. if not object:
  397. raise Exception(f'Cannot find object to "{method}": {guid}')
  398. if method == "__adopt__":
  399. child_guid = cast(Dict[str, str], params)["guid"]
  400. child = self._objects.get(child_guid)
  401. if not child:
  402. raise Exception(f"Unknown new child: {child_guid}")
  403. object._adopt(child)
  404. return
  405. if method == "__dispose__":
  406. assert isinstance(params, dict)
  407. self._objects[guid]._dispose(cast(Optional[str], params.get("reason")))
  408. return
  409. object = self._objects[guid]
  410. should_replace_guids_with_channels = "jsonPipe@" not in guid
  411. try:
  412. if self._is_sync:
  413. for listener in object._channel.listeners(method):
  414. # Event handlers like route/locatorHandlerTriggered require us to perform async work.
  415. # In order to report their potential errors to the user, we need to catch it and store it in the connection
  416. def _done_callback(future: asyncio.Future) -> None:
  417. exc = future.exception()
  418. if exc:
  419. self._on_event_listener_error(exc)
  420. def _listener_with_error_handler_attached(params: Any) -> None:
  421. potential_future = listener(params)
  422. if asyncio.isfuture(potential_future):
  423. potential_future.add_done_callback(_done_callback)
  424. # Each event handler is a potentilly blocking context, create a fiber for each
  425. # and switch to them in order, until they block inside and pass control to each
  426. # other and then eventually back to dispatcher as listener functions return.
  427. g = EventGreenlet(_listener_with_error_handler_attached)
  428. if should_replace_guids_with_channels:
  429. g.switch(self._replace_guids_with_channels(params))
  430. else:
  431. g.switch(params)
  432. else:
  433. if should_replace_guids_with_channels:
  434. object._channel.emit(
  435. method, self._replace_guids_with_channels(params)
  436. )
  437. else:
  438. object._channel.emit(method, params)
  439. except BaseException as exc:
  440. self._on_event_listener_error(exc)
  441. def _on_event_listener_error(self, exc: BaseException) -> None:
  442. print("Error occurred in event listener", file=sys.stderr)
  443. traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
  444. # Save the error to throw at the next API call. This "replicates" unhandled rejection in Node.js.
  445. self._error = exc
  446. def _create_remote_object(
  447. self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
  448. ) -> ChannelOwner:
  449. initializer = self._replace_guids_with_channels(initializer)
  450. result = self._object_factory(parent, type, guid, initializer)
  451. if guid in self._waiting_for_object:
  452. self._waiting_for_object.pop(guid)(result)
  453. return result
  454. def _replace_channels_with_guids(
  455. self,
  456. payload: Any,
  457. ) -> Any:
  458. if payload is None:
  459. return payload
  460. if isinstance(payload, Path):
  461. return str(payload)
  462. if isinstance(payload, collections.abc.Sequence) and not isinstance(
  463. payload, str
  464. ):
  465. return list(map(self._replace_channels_with_guids, payload))
  466. if isinstance(payload, Channel):
  467. return dict(guid=payload._guid)
  468. if isinstance(payload, dict):
  469. result = {}
  470. for key, value in payload.items():
  471. result[key] = self._replace_channels_with_guids(value)
  472. return result
  473. return payload
  474. def _replace_guids_with_channels(self, payload: Any) -> Any:
  475. if payload is None:
  476. return payload
  477. if isinstance(payload, list):
  478. return list(map(self._replace_guids_with_channels, payload))
  479. if isinstance(payload, dict):
  480. if payload.get("guid") in self._objects:
  481. return self._objects[payload["guid"]]._channel
  482. result = {}
  483. for key, value in payload.items():
  484. result[key] = self._replace_guids_with_channels(value)
  485. return result
  486. return payload
  487. async def wrap_api_call(
  488. self, cb: Callable[[], Any], is_internal: bool = False, title: str = None
  489. ) -> Any:
  490. if self._api_zone.get():
  491. return await cb()
  492. task = asyncio.current_task(self._loop)
  493. st: List[inspect.FrameInfo] = getattr(
  494. task, "__pw_stack__", None
  495. ) or inspect.stack(0)
  496. parsed_st = _extract_stack_trace_information_from_stack(st, is_internal, title)
  497. self._api_zone.set(parsed_st)
  498. try:
  499. return await cb()
  500. except Exception as error:
  501. raise rewrite_error(error, f"{parsed_st['apiName']}: {error}") from None
  502. finally:
  503. self._api_zone.set(None)
  504. def wrap_api_call_sync(
  505. self, cb: Callable[[], Any], is_internal: bool = False, title: str = None
  506. ) -> Any:
  507. if self._api_zone.get():
  508. return cb()
  509. task = asyncio.current_task(self._loop)
  510. st: List[inspect.FrameInfo] = getattr(
  511. task, "__pw_stack__", None
  512. ) or inspect.stack(0)
  513. parsed_st = _extract_stack_trace_information_from_stack(st, is_internal, title)
  514. self._api_zone.set(parsed_st)
  515. try:
  516. return cb()
  517. except Exception as error:
  518. raise rewrite_error(error, f"{parsed_st['apiName']}: {error}") from None
  519. finally:
  520. self._api_zone.set(None)
  521. def from_channel(channel: Channel) -> Any:
  522. return channel._object
  523. def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]:
  524. return channel._object if channel else None
class StackFrame(TypedDict):
    """One stack frame as attached to outgoing protocol messages."""

    # File name of the frame.
    file: str
    # Line number of the frame.
    line: int
    # Column; the stack parser in this file always stores 0.
    column: int
    # "<ClassName>.<function>" when the frame has a `self` local, else "<function>".
    function: Optional[str]
class ParsedStackTrace(TypedDict):
    """Stack information stored in the API zone for the duration of an API call."""

    # Frames whose file is outside the playwright package.
    frames: List[StackFrame]
    # Name of the API method being called; "" for internal calls.
    apiName: Optional[str]
    # Optional title passed through from the wrapped API call.
    title: Optional[str]
  534. def _extract_stack_trace_information_from_stack(
  535. st: List[inspect.FrameInfo], is_internal: bool, title: str = None
  536. ) -> ParsedStackTrace:
  537. playwright_module_path = str(Path(playwright.__file__).parents[0])
  538. last_internal_api_name = ""
  539. api_name = ""
  540. parsed_frames: List[StackFrame] = []
  541. for frame in st:
  542. # Sync and Async implementations can have event handlers. When these are sync, they
  543. # get evaluated in the context of the event loop, so they contain the stack trace of when
  544. # the message was received. _impl_to_api_mapping is glue between the user-code and internal
  545. # code to translate impl classes to api classes. We want to ignore these frames.
  546. if playwright._impl._impl_to_api_mapping.__file__ == frame.filename:
  547. continue
  548. is_playwright_internal = frame.filename.startswith(playwright_module_path)
  549. method_name = ""
  550. if "self" in frame[0].f_locals:
  551. method_name = frame[0].f_locals["self"].__class__.__name__ + "."
  552. method_name += frame[0].f_code.co_name
  553. if not is_playwright_internal:
  554. parsed_frames.append(
  555. {
  556. "file": frame.filename,
  557. "line": frame.lineno,
  558. "column": 0,
  559. "function": method_name,
  560. }
  561. )
  562. if is_playwright_internal:
  563. last_internal_api_name = method_name
  564. elif last_internal_api_name:
  565. api_name = last_internal_api_name
  566. last_internal_api_name = ""
  567. if not api_name:
  568. api_name = last_internal_api_name
  569. return {
  570. "frames": parsed_frames,
  571. "apiName": "" if is_internal else api_name,
  572. "title": title,
  573. }
  574. def _augment_params(
  575. params: Optional[Dict],
  576. timeout_calculator: Optional[Callable[[Optional[float]], float]],
  577. ) -> Dict:
  578. if params is None:
  579. params = {}
  580. if timeout_calculator:
  581. params["timeout"] = timeout_calculator(params.get("timeout"))
  582. return _filter_none(params)
  583. def _filter_none(d: Mapping) -> Dict:
  584. result = {}
  585. for k, v in d.items():
  586. if v is None:
  587. continue
  588. result[k] = _filter_none(v) if isinstance(v, dict) else v
  589. return result
  590. def format_call_log(log: Optional[List[str]]) -> str:
  591. if not log:
  592. return ""
  593. if len(list(filter(lambda x: x.strip(), log))) == 0:
  594. return ""
  595. return "\nCall log:\n" + "\n".join(log) + "\n"