_async_base.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from contextlib import AbstractAsyncContextManager
  16. from types import TracebackType
  17. from typing import Any, Callable, Generic, Optional, Type, TypeVar, Union
  18. from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper
  19. mapping = ImplToApiMapping()
  20. T = TypeVar("T")
  21. Self = TypeVar("Self", bound="AsyncContextManager")
  22. class AsyncEventInfo(Generic[T]):
  23. def __init__(self, future: "asyncio.Future[T]") -> None:
  24. self._future = future
  25. @property
  26. async def value(self) -> T:
  27. return mapping.from_maybe_impl(await self._future)
  28. def _cancel(self) -> None:
  29. self._future.cancel()
  30. def is_done(self) -> bool:
  31. return self._future.done()
  32. class AsyncEventContextManager(Generic[T], AbstractAsyncContextManager):
  33. def __init__(self, future: "asyncio.Future[T]") -> None:
  34. self._event = AsyncEventInfo[T](future)
  35. async def __aenter__(self) -> AsyncEventInfo[T]:
  36. return self._event
  37. async def __aexit__(
  38. self,
  39. exc_type: Optional[Type[BaseException]],
  40. exc_val: Optional[BaseException],
  41. exc_tb: Optional[TracebackType],
  42. ) -> None:
  43. if exc_val:
  44. self._event._cancel()
  45. else:
  46. await self._event.value
  47. class AsyncBase(ImplWrapper):
  48. def __init__(self, impl_obj: Any) -> None:
  49. super().__init__(impl_obj)
  50. self._loop = impl_obj._loop
  51. def __str__(self) -> str:
  52. return self._impl_obj.__str__()
  53. def _wrap_handler(
  54. self, handler: Union[Callable[..., Any], Any]
  55. ) -> Callable[..., None]:
  56. if callable(handler):
  57. return mapping.wrap_handler(handler)
  58. return handler
  59. def on(self, event: Any, f: Any) -> None:
  60. """Registers the function ``f`` to the event name ``event``."""
  61. self._impl_obj.on(event, self._wrap_handler(f))
  62. def once(self, event: Any, f: Any) -> None:
  63. """The same as ``self.on``, except that the listener is automatically
  64. removed after being called.
  65. """
  66. self._impl_obj.once(event, self._wrap_handler(f))
  67. def remove_listener(self, event: Any, f: Any) -> None:
  68. """Removes the function ``f`` from ``event``."""
  69. self._impl_obj.remove_listener(event, self._wrap_handler(f))
  70. class AsyncContextManager(AsyncBase):
  71. async def __aenter__(self: Self) -> Self:
  72. return self
  73. async def __aexit__(
  74. self,
  75. exc_type: Optional[Type[BaseException]],
  76. exc_val: Optional[BaseException],
  77. traceback: Optional[TracebackType],
  78. ) -> None:
  79. await self.close()
  80. async def close(self) -> None: ...