_waiter.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import math
  16. import uuid
  17. from asyncio.tasks import Task
  18. from typing import Any, Callable, List, Tuple, Union
  19. from pyee import EventEmitter
  20. from playwright._impl._connection import ChannelOwner
  21. from playwright._impl._errors import Error, TimeoutError
  22. class Waiter:
  23. def __init__(self, channel_owner: ChannelOwner, event: str) -> None:
  24. self._result: asyncio.Future = asyncio.Future()
  25. self._wait_id = uuid.uuid4().hex
  26. self._loop = channel_owner._loop
  27. self._pending_tasks: List[Task] = []
  28. self._channel = channel_owner._channel
  29. self._registered_listeners: List[Tuple[EventEmitter, str, Callable]] = []
  30. self._logs: List[str] = []
  31. self._wait_for_event_info_before(self._wait_id, event)
  32. def _wait_for_event_info_before(self, wait_id: str, event: str) -> None:
  33. self._channel.send_no_reply(
  34. "waitForEventInfo",
  35. None,
  36. {
  37. "info": {
  38. "waitId": wait_id,
  39. "phase": "before",
  40. "event": event,
  41. }
  42. },
  43. )
  44. def _wait_for_event_info_after(self, wait_id: str, error: Exception = None) -> None:
  45. self._channel._connection.wrap_api_call_sync(
  46. lambda: self._channel.send_no_reply(
  47. "waitForEventInfo",
  48. None,
  49. {
  50. "info": {
  51. "waitId": wait_id,
  52. "phase": "after",
  53. **({"error": str(error)} if error else {}),
  54. },
  55. },
  56. ),
  57. True,
  58. )
  59. def reject_on_event(
  60. self,
  61. emitter: EventEmitter,
  62. event: str,
  63. error: Union[Error, Callable[..., Error]],
  64. predicate: Callable = None,
  65. ) -> None:
  66. def listener(event_data: Any = None) -> None:
  67. if not predicate or predicate(event_data):
  68. self._reject(error() if callable(error) else error)
  69. emitter.on(event, listener)
  70. self._registered_listeners.append((emitter, event, listener))
  71. def reject_on_timeout(self, timeout: float, message: str) -> None:
  72. if timeout == 0:
  73. return
  74. async def reject() -> None:
  75. await asyncio.sleep(timeout / 1000)
  76. self._reject(TimeoutError(message))
  77. self._pending_tasks.append(self._loop.create_task(reject()))
  78. def _cleanup(self) -> None:
  79. for task in self._pending_tasks:
  80. if not task.done():
  81. task.cancel()
  82. for listener in self._registered_listeners:
  83. listener[0].remove_listener(listener[1], listener[2])
  84. def _fulfill(self, result: Any) -> None:
  85. self._cleanup()
  86. if not self._result.done():
  87. self._result.set_result(result)
  88. self._wait_for_event_info_after(self._wait_id)
  89. def _reject(self, exception: Exception) -> None:
  90. self._cleanup()
  91. if exception:
  92. base_class = TimeoutError if isinstance(exception, TimeoutError) else Error
  93. exception = base_class(str(exception) + format_log_recording(self._logs))
  94. if not self._result.done():
  95. self._result.set_exception(exception)
  96. self._wait_for_event_info_after(self._wait_id, exception)
  97. def wait_for_event(
  98. self,
  99. emitter: EventEmitter,
  100. event: str,
  101. predicate: Callable = None,
  102. ) -> None:
  103. def listener(event_data: Any = None) -> None:
  104. if not predicate or predicate(event_data):
  105. self._fulfill(event_data)
  106. emitter.on(event, listener)
  107. self._registered_listeners.append((emitter, event, listener))
  108. def result(self) -> asyncio.Future:
  109. return self._result
  110. def log(self, message: str) -> None:
  111. self._logs.append(message)
  112. try:
  113. self._channel._connection.wrap_api_call_sync(
  114. lambda: self._channel.send_no_reply(
  115. "waitForEventInfo",
  116. None,
  117. {
  118. "info": {
  119. "waitId": self._wait_id,
  120. "phase": "log",
  121. "message": message,
  122. },
  123. },
  124. ),
  125. True,
  126. )
  127. except Exception:
  128. pass
  129. def throw_on_timeout(timeout: float, exception: Exception) -> asyncio.Task:
  130. async def throw() -> None:
  131. await asyncio.sleep(timeout / 1000)
  132. raise exception
  133. return asyncio.create_task(throw())
  134. def format_log_recording(log: List[str]) -> str:
  135. if not log:
  136. return ""
  137. header = " logs "
  138. header_length = 60
  139. left_length = math.floor((header_length - len(header)) / 2)
  140. right_length = header_length - len(header) - left_length
  141. new_line = "\n"
  142. return f"{new_line}{'=' * left_length}{header}{'=' * right_length}{new_line}{new_line.join(log)}{new_line}{'=' * header_length}"