_helper.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import math
  16. import os
  17. import re
  18. import time
  19. import traceback
  20. from pathlib import Path
  21. from types import TracebackType
  22. from typing import (
  23. TYPE_CHECKING,
  24. Any,
  25. Callable,
  26. Dict,
  27. List,
  28. Literal,
  29. Optional,
  30. Pattern,
  31. Set,
  32. Tuple,
  33. TypedDict,
  34. TypeVar,
  35. Union,
  36. cast,
  37. )
  38. from urllib.parse import ParseResult, urljoin, urlparse, urlunparse
  39. from playwright._impl._api_structures import NameValue
  40. from playwright._impl._errors import (
  41. Error,
  42. TargetClosedError,
  43. TimeoutError,
  44. is_target_closed_error,
  45. rewrite_error,
  46. )
  47. from playwright._impl._glob import glob_to_regex_pattern
  48. from playwright._impl._greenlets import RouteGreenlet
  49. from playwright._impl._str_utils import escape_regex_flags
  50. if TYPE_CHECKING: # pragma: no cover
  51. from playwright._impl._api_structures import HeadersArray
  52. from playwright._impl._network import Request, Response, Route, WebSocketRoute
  53. URLMatch = Union[str, Pattern[str], Callable[[str], bool]]
  54. URLMatchRequest = Union[str, Pattern[str], Callable[["Request"], bool]]
  55. URLMatchResponse = Union[str, Pattern[str], Callable[["Response"], bool]]
  56. RouteHandlerCallback = Union[
  57. Callable[["Route"], Any], Callable[["Route", "Request"], Any]
  58. ]
  59. WebSocketRouteHandlerCallback = Callable[["WebSocketRoute"], Any]
  60. ColorScheme = Literal["dark", "light", "no-preference", "null"]
  61. ForcedColors = Literal["active", "none", "null"]
  62. Contrast = Literal["more", "no-preference", "null"]
  63. ReducedMotion = Literal["no-preference", "null", "reduce"]
  64. DocumentLoadState = Literal["commit", "domcontentloaded", "load", "networkidle"]
  65. KeyboardModifier = Literal["Alt", "Control", "ControlOrMeta", "Meta", "Shift"]
  66. MouseButton = Literal["left", "middle", "right"]
  67. ServiceWorkersPolicy = Literal["allow", "block"]
  68. HarMode = Literal["full", "minimal"]
  69. HarContentPolicy = Literal["attach", "embed", "omit"]
  70. RouteFromHarNotFoundPolicy = Literal["abort", "fallback"]
  71. class ErrorPayload(TypedDict, total=False):
  72. message: str
  73. name: str
  74. stack: str
  75. value: Optional[Any]
  76. class HarRecordingMetadata(TypedDict, total=False):
  77. path: str
  78. content: Optional[HarContentPolicy]
  79. def prepare_record_har_options(params: Dict) -> Dict[str, Any]:
  80. out_params: Dict[str, Any] = {"path": str(params["recordHarPath"])}
  81. if "recordHarUrlFilter" in params:
  82. opt = params["recordHarUrlFilter"]
  83. if isinstance(opt, str):
  84. out_params["urlGlob"] = opt
  85. if isinstance(opt, Pattern):
  86. out_params["urlRegexSource"] = opt.pattern
  87. out_params["urlRegexFlags"] = escape_regex_flags(opt)
  88. del params["recordHarUrlFilter"]
  89. if "recordHarMode" in params:
  90. out_params["mode"] = params["recordHarMode"]
  91. del params["recordHarMode"]
  92. new_content_api = None
  93. old_content_api = None
  94. if "recordHarContent" in params:
  95. new_content_api = params["recordHarContent"]
  96. del params["recordHarContent"]
  97. if "recordHarOmitContent" in params:
  98. old_content_api = params["recordHarOmitContent"]
  99. del params["recordHarOmitContent"]
  100. content = new_content_api or ("omit" if old_content_api else None)
  101. if content:
  102. out_params["content"] = content
  103. return out_params
  104. class ParsedMessageParams(TypedDict):
  105. type: str
  106. guid: str
  107. initializer: Dict
  108. class ParsedMessagePayload(TypedDict, total=False):
  109. id: int
  110. guid: str
  111. method: str
  112. params: ParsedMessageParams
  113. result: Any
  114. error: ErrorPayload
  115. class Document(TypedDict):
  116. request: Optional[Any]
  117. class FrameNavigatedEvent(TypedDict):
  118. url: str
  119. name: str
  120. newDocument: Optional[Document]
  121. error: Optional[str]
  122. Env = Dict[str, Union[str, float, bool]]
  123. def url_matches(
  124. base_url: Optional[str],
  125. url_string: str,
  126. match: Optional[URLMatch],
  127. websocket_url: bool = None,
  128. ) -> bool:
  129. if not match:
  130. return True
  131. if isinstance(match, str):
  132. match = re.compile(
  133. resolve_glob_to_regex_pattern(base_url, match, websocket_url)
  134. )
  135. if isinstance(match, Pattern):
  136. return bool(match.search(url_string))
  137. return match(url_string)
  138. def resolve_glob_to_regex_pattern(
  139. base_url: Optional[str], glob: str, websocket_url: bool = None
  140. ) -> str:
  141. if websocket_url:
  142. base_url = to_websocket_base_url(base_url)
  143. glob = resolve_glob_base(base_url, glob)
  144. return glob_to_regex_pattern(glob)
  145. def to_websocket_base_url(base_url: Optional[str]) -> Optional[str]:
  146. if base_url is not None and re.match(r"^https?://", base_url):
  147. base_url = re.sub(r"^http", "ws", base_url)
  148. return base_url
  149. def resolve_glob_base(base_url: Optional[str], match: str) -> str:
  150. if match[0] == "*":
  151. return match
  152. token_map: Dict[str, str] = {}
  153. def map_token(original: str, replacement: str) -> str:
  154. if len(original) == 0:
  155. return ""
  156. token_map[replacement] = original
  157. return replacement
  158. # Escaped `\\?` behaves the same as `?` in our glob patterns.
  159. match = match.replace(r"\\?", "?")
  160. # Special case about: URLs as they are not relative to base_url
  161. if (
  162. match.startswith("about:")
  163. or match.startswith("data:")
  164. or match.startswith("chrome:")
  165. or match.startswith("edge:")
  166. or match.startswith("file:")
  167. ):
  168. # about: and data: URLs are not relative to base_url, so we return them as is.
  169. return match
  170. # Glob symbols may be escaped in the URL and some of them such as ? affect resolution,
  171. # so we replace them with safe components first.
  172. processed_parts = []
  173. for index, token in enumerate(match.split("/")):
  174. if token in (".", "..", ""):
  175. processed_parts.append(token)
  176. continue
  177. # Handle special case of http*://, note that the new schema has to be
  178. # a web schema so that slashes are properly inserted after domain.
  179. if index == 0 and token.endswith(":"):
  180. # Replace any pattern with http:
  181. if "*" in token or "{" in token:
  182. processed_parts.append(map_token(token, "http:"))
  183. else:
  184. # Preserve explicit schema as is as it may affect trailing slashes after domain.
  185. processed_parts.append(token)
  186. continue
  187. question_index = token.find("?")
  188. if question_index == -1:
  189. processed_parts.append(map_token(token, f"$_{index}_$"))
  190. else:
  191. new_prefix = map_token(token[:question_index], f"$_{index}_$")
  192. new_suffix = map_token(token[question_index:], f"?$_{index}_$")
  193. processed_parts.append(new_prefix + new_suffix)
  194. relative_path = "/".join(processed_parts)
  195. resolved, case_insensitive_part = resolve_base_url(base_url, relative_path)
  196. for token, original in token_map.items():
  197. normalize = case_insensitive_part and token in case_insensitive_part
  198. resolved = resolved.replace(
  199. token, original.lower() if normalize else original, 1
  200. )
  201. return resolved
  202. def resolve_base_url(
  203. base_url: Optional[str], given_url: str
  204. ) -> Tuple[str, Optional[str]]:
  205. try:
  206. url = nodelike_urlparse(
  207. urljoin(base_url if base_url is not None else "", given_url)
  208. )
  209. resolved = urlunparse(url)
  210. # Schema and domain are case-insensitive.
  211. hostname_port = (
  212. url.hostname or ""
  213. ) # can't use parsed.netloc because it includes userinfo (username:password)
  214. if url.port:
  215. hostname_port += f":{url.port}"
  216. case_insensitive_prefix = f"{url.scheme}://{hostname_port}"
  217. return resolved, case_insensitive_prefix
  218. except Exception:
  219. return given_url, None
  220. def nodelike_urlparse(url: str) -> ParseResult:
  221. parsed = urlparse(url, allow_fragments=True)
  222. # https://url.spec.whatwg.org/#special-scheme
  223. is_special_url = parsed.scheme in ["http", "https", "ws", "wss", "ftp", "file"]
  224. if is_special_url:
  225. # special urls have a list path, list paths are serialized as follows: https://url.spec.whatwg.org/#url-path-serializer
  226. # urllib diverges, so we patch it here
  227. if parsed.path == "":
  228. parsed = parsed._replace(path="/")
  229. return parsed
  230. class HarLookupResult(TypedDict, total=False):
  231. action: Literal["error", "redirect", "fulfill", "noentry"]
  232. message: Optional[str]
  233. redirectURL: Optional[str]
  234. status: Optional[int]
  235. headers: Optional["HeadersArray"]
  236. body: Optional[str]
  237. DEFAULT_PLAYWRIGHT_TIMEOUT_IN_MILLISECONDS = 30000
  238. DEFAULT_PLAYWRIGHT_LAUNCH_TIMEOUT_IN_MILLISECONDS = 180000
  239. PLAYWRIGHT_MAX_DEADLINE = 2147483647 # 2^31-1
  240. class TimeoutSettings:
  241. @staticmethod
  242. def launch_timeout(timeout: Optional[float] = None) -> float:
  243. return (
  244. timeout
  245. if timeout is not None
  246. else DEFAULT_PLAYWRIGHT_LAUNCH_TIMEOUT_IN_MILLISECONDS
  247. )
  248. def __init__(self, parent: Optional["TimeoutSettings"]) -> None:
  249. self._parent = parent
  250. self._default_timeout: Optional[float] = None
  251. self._default_navigation_timeout: Optional[float] = None
  252. def set_default_timeout(self, timeout: Optional[float]) -> None:
  253. self._default_timeout = timeout
  254. def timeout(self, timeout: float = None) -> float:
  255. if timeout is not None:
  256. return timeout
  257. if self._default_timeout is not None:
  258. return self._default_timeout
  259. if self._parent:
  260. return self._parent.timeout()
  261. return DEFAULT_PLAYWRIGHT_TIMEOUT_IN_MILLISECONDS
  262. def set_default_navigation_timeout(
  263. self, navigation_timeout: Optional[float]
  264. ) -> None:
  265. self._default_navigation_timeout = navigation_timeout
  266. def default_navigation_timeout(self) -> Optional[float]:
  267. return self._default_navigation_timeout
  268. def default_timeout(self) -> Optional[float]:
  269. return self._default_timeout
  270. def navigation_timeout(self, timeout: float = None) -> float:
  271. if timeout is not None:
  272. return timeout
  273. if self._default_navigation_timeout is not None:
  274. return self._default_navigation_timeout
  275. if self._default_timeout is not None:
  276. return self._default_timeout
  277. if self._parent:
  278. return self._parent.navigation_timeout()
  279. return DEFAULT_PLAYWRIGHT_TIMEOUT_IN_MILLISECONDS
  280. def serialize_error(ex: Exception, tb: Optional[TracebackType]) -> ErrorPayload:
  281. return ErrorPayload(
  282. message=str(ex), name="Error", stack="".join(traceback.format_tb(tb))
  283. )
  284. def parse_error(error: ErrorPayload, log: Optional[str] = None) -> Error:
  285. base_error_class = Error
  286. if error.get("name") == "TimeoutError":
  287. base_error_class = TimeoutError
  288. if error.get("name") == "TargetClosedError":
  289. base_error_class = TargetClosedError
  290. if not log:
  291. log = ""
  292. exc = base_error_class(patch_error_message(error["message"]) + log)
  293. exc._name = error["name"]
  294. exc._stack = error["stack"]
  295. return exc
  296. def patch_error_message(message: str) -> str:
  297. match = re.match(r"(\w+)(: expected .*)", message)
  298. if match:
  299. message = to_snake_case(match.group(1)) + match.group(2)
  300. message = message.replace(
  301. "Pass { acceptDownloads: true }", "Pass 'accept_downloads=True'"
  302. )
  303. return message
  304. def locals_to_params(args: Dict) -> Dict:
  305. copy = {}
  306. for key in args:
  307. if key == "self":
  308. continue
  309. if args[key] is not None:
  310. copy[key] = (
  311. args[key]
  312. if not isinstance(args[key], Dict)
  313. else locals_to_params(args[key])
  314. )
  315. return copy
  316. def monotonic_time() -> int:
  317. return math.floor(time.monotonic() * 1000)
  318. class RouteHandlerInvocation:
  319. complete: "asyncio.Future"
  320. route: "Route"
  321. def __init__(self, complete: "asyncio.Future", route: "Route") -> None:
  322. self.complete = complete
  323. self.route = route
  324. class RouteHandler:
  325. def __init__(
  326. self,
  327. base_url: Optional[str],
  328. url: URLMatch,
  329. handler: RouteHandlerCallback,
  330. is_sync: bool,
  331. times: Optional[int] = None,
  332. ):
  333. self._base_url = base_url
  334. self.url = url
  335. self.handler = handler
  336. self._times = times if times else math.inf
  337. self._handled_count = 0
  338. self._is_sync = is_sync
  339. self._ignore_exception = False
  340. self._active_invocations: Set[RouteHandlerInvocation] = set()
  341. def matches(self, request_url: str) -> bool:
  342. return url_matches(self._base_url, request_url, self.url)
  343. async def handle(self, route: "Route") -> bool:
  344. handler_invocation = RouteHandlerInvocation(
  345. asyncio.get_running_loop().create_future(), route
  346. )
  347. self._active_invocations.add(handler_invocation)
  348. try:
  349. return await self._handle_internal(route)
  350. except Exception as e:
  351. # If the handler was stopped (without waiting for completion), we ignore all exceptions.
  352. if self._ignore_exception:
  353. return False
  354. if is_target_closed_error(e):
  355. # We are failing in the handler because the target has closed.
  356. # Give user a hint!
  357. optional_async_prefix = "await " if not self._is_sync else ""
  358. raise rewrite_error(
  359. e,
  360. f"\"{str(e)}\" while running route callback.\nConsider awaiting `{optional_async_prefix}page.unroute_all(behavior='ignoreErrors')`\nbefore the end of the test to ignore remaining routes in flight.",
  361. )
  362. raise e
  363. finally:
  364. handler_invocation.complete.set_result(None)
  365. self._active_invocations.remove(handler_invocation)
  366. async def _handle_internal(self, route: "Route") -> bool:
  367. handled_future = route._start_handling()
  368. self._handled_count += 1
  369. if self._is_sync:
  370. handler_finished_future = route._loop.create_future()
  371. def _handler() -> None:
  372. try:
  373. self.handler(route, route.request) # type: ignore
  374. handler_finished_future.set_result(None)
  375. except Exception as e:
  376. handler_finished_future.set_exception(e)
  377. # As with event handlers, each route handler is a potentially blocking context
  378. # so it needs a fiber.
  379. g = RouteGreenlet(_handler)
  380. g.switch()
  381. await handler_finished_future
  382. else:
  383. coro_or_future = self.handler(route, route.request) # type: ignore
  384. if coro_or_future:
  385. # separate task so that we get a proper stack trace for exceptions / tracing api_name extraction
  386. await asyncio.ensure_future(coro_or_future)
  387. return await handled_future
  388. async def stop(self, behavior: Literal["ignoreErrors", "wait"]) -> None:
  389. # When a handler is manually unrouted or its page/context is closed we either
  390. # - wait for the current handler invocations to finish
  391. # - or do not wait, if the user opted out of it, but swallow all exceptions
  392. # that happen after the unroute/close.
  393. if behavior == "ignoreErrors":
  394. self._ignore_exception = True
  395. else:
  396. tasks = []
  397. for activation in self._active_invocations:
  398. if not activation.route._did_throw:
  399. tasks.append(activation.complete)
  400. await asyncio.gather(*tasks)
  401. @property
  402. def will_expire(self) -> bool:
  403. return self._handled_count + 1 >= self._times
  404. @staticmethod
  405. def prepare_interception_patterns(
  406. handlers: List["RouteHandler"],
  407. ) -> List[Dict[str, str]]:
  408. patterns = []
  409. all = False
  410. for handler in handlers:
  411. if isinstance(handler.url, str):
  412. patterns.append({"glob": handler.url})
  413. elif isinstance(handler.url, re.Pattern):
  414. patterns.append(
  415. {
  416. "regexSource": handler.url.pattern,
  417. "regexFlags": escape_regex_flags(handler.url),
  418. }
  419. )
  420. else:
  421. all = True
  422. if all:
  423. return [{"glob": "**/*"}]
  424. return patterns
  425. to_snake_case_regex = re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
  426. def to_snake_case(name: str) -> str:
  427. return to_snake_case_regex.sub(r"_\1", name).lower()
  428. def make_dirs_for_file(path: Union[Path, str]) -> None:
  429. if not os.path.isabs(path):
  430. path = Path.cwd() / path
  431. os.makedirs(os.path.dirname(path), exist_ok=True)
  432. async def async_writefile(file: Union[str, Path], data: Union[str, bytes]) -> None:
  433. def inner() -> None:
  434. with open(file, "w" if isinstance(data, str) else "wb") as fh:
  435. fh.write(data)
  436. loop = asyncio.get_running_loop()
  437. await loop.run_in_executor(None, inner)
  438. async def async_readfile(file: Union[str, Path]) -> bytes:
  439. def inner() -> bytes:
  440. with open(file, "rb") as fh:
  441. return fh.read()
  442. loop = asyncio.get_running_loop()
  443. return await loop.run_in_executor(None, inner)
  444. T = TypeVar("T")
  445. def to_impl(obj: T) -> T:
  446. if hasattr(obj, "_impl_obj"):
  447. return cast(Any, obj)._impl_obj
  448. return obj
  449. def object_to_array(obj: Optional[Dict]) -> Optional[List[NameValue]]:
  450. if not obj:
  451. return None
  452. result = []
  453. for key, value in obj.items():
  454. result.append(NameValue(name=key, value=str(value)))
  455. return result
  456. def is_file_payload(value: Optional[Any]) -> bool:
  457. return (
  458. isinstance(value, dict)
  459. and "name" in value
  460. and "mimeType" in value
  461. and "buffer" in value
  462. )
  463. TEXTUAL_MIME_TYPE = re.compile(
  464. r"^(text\/.*?|application\/(json|(x-)?javascript|xml.*?|ecmascript|graphql|x-www-form-urlencoded)|image\/svg(\+xml)?|application\/.*?(\+json|\+xml))(;\s*charset=.*)?$"
  465. )
  466. def is_textual_mime_type(mime_type: str) -> bool:
  467. return bool(TEXTUAL_MIME_TYPE.match(mime_type))