# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import TYPE_CHECKING, Any, Optional, cast

from greenlet import greenlet

from playwright._impl._connection import ChannelOwner, Connection
from playwright._impl._errors import Error
from playwright._impl._greenlets import MainGreenlet
from playwright._impl._object_factory import create_remote_object
from playwright._impl._playwright import Playwright
from playwright._impl._transport import PipeTransport
from playwright.sync_api._generated import Playwright as SyncPlaywright

if TYPE_CHECKING:
    from asyncio.unix_events import AbstractChildWatcher


  26. class PlaywrightContextManager:
  27. def __init__(self) -> None:
  28. self._playwright: SyncPlaywright
  29. self._loop: asyncio.AbstractEventLoop
  30. self._own_loop = False
  31. self._watcher: Optional[AbstractChildWatcher] = None
  32. self._exit_was_called = False
  33. def __enter__(self) -> SyncPlaywright:
  34. try:
  35. self._loop = asyncio.get_running_loop()
  36. except RuntimeError:
  37. self._loop = asyncio.new_event_loop()
  38. self._own_loop = True
  39. if self._loop.is_running():
  40. raise Error(
  41. """It looks like you are using Playwright Sync API inside the asyncio loop.
  42. Please use the Async API instead."""
  43. )
  44. # Create a new fiber for the protocol dispatcher. It will be pumping events
  45. # until the end of times. We will pass control to that fiber every time we
  46. # block while waiting for a response.
  47. def greenlet_main() -> None:
  48. self._loop.run_until_complete(self._connection.run_as_sync())
  49. dispatcher_fiber = MainGreenlet(greenlet_main)
  50. self._connection = Connection(
  51. dispatcher_fiber,
  52. create_remote_object,
  53. PipeTransport(self._loop),
  54. self._loop,
  55. )
  56. g_self = greenlet.getcurrent()
  57. def callback_wrapper(channel_owner: ChannelOwner) -> None:
  58. playwright_impl = cast(Playwright, channel_owner)
  59. self._playwright = SyncPlaywright(playwright_impl)
  60. g_self.switch()
  61. # Switch control to the dispatcher, it'll fire an event and pass control to
  62. # the calling greenlet.
  63. self._connection.call_on_object_with_known_name("Playwright", callback_wrapper)
  64. dispatcher_fiber.switch()
  65. playwright = self._playwright
  66. playwright.stop = self.__exit__ # type: ignore
  67. return playwright
  68. def start(self) -> SyncPlaywright:
  69. return self.__enter__()
  70. def __exit__(self, *args: Any) -> None:
  71. if self._exit_was_called:
  72. return
  73. self._exit_was_called = True
  74. self._connection.stop_sync()
  75. if self._watcher:
  76. self._watcher.close()
  77. if self._own_loop:
  78. tasks = asyncio.all_tasks(self._loop)
  79. for t in [t for t in tasks if not (t.done() or t.cancelled())]:
  80. t.cancel()
  81. self._loop.run_until_complete(self._loop.shutdown_asyncgens())
  82. self._loop.close()