_context_manager.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from typing import Any
  16. from playwright._impl._connection import Connection
  17. from playwright._impl._object_factory import create_remote_object
  18. from playwright._impl._transport import PipeTransport
  19. from playwright.async_api._generated import Playwright as AsyncPlaywright
  20. class PlaywrightContextManager:
  21. def __init__(self) -> None:
  22. self._connection: Connection
  23. self._exit_was_called = False
  24. async def __aenter__(self) -> AsyncPlaywright:
  25. loop = asyncio.get_running_loop()
  26. self._connection = Connection(
  27. None,
  28. create_remote_object,
  29. PipeTransport(loop),
  30. loop,
  31. )
  32. loop.create_task(self._connection.run())
  33. playwright_future = self._connection.playwright_future
  34. done, _ = await asyncio.wait(
  35. {self._connection._transport.on_error_future, playwright_future},
  36. return_when=asyncio.FIRST_COMPLETED,
  37. )
  38. if not playwright_future.done():
  39. playwright_future.cancel()
  40. playwright = AsyncPlaywright(next(iter(done)).result())
  41. playwright.stop = self.__aexit__ # type: ignore
  42. return playwright
  43. async def start(self) -> AsyncPlaywright:
  44. return await self.__aenter__()
  45. async def __aexit__(self, *args: Any) -> None:
  46. if self._exit_was_called:
  47. return
  48. self._exit_was_called = True
  49. await self._connection.stop_async()