_network.py 34 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036
  1. # Copyright (c) Microsoft Corporation.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import base64
  16. import inspect
  17. import json
  18. import json as json_utils
  19. import mimetypes
  20. import re
  21. from collections import defaultdict
  22. from pathlib import Path
  23. from types import SimpleNamespace
  24. from typing import (
  25. TYPE_CHECKING,
  26. Any,
  27. Callable,
  28. Coroutine,
  29. Dict,
  30. List,
  31. Optional,
  32. TypedDict,
  33. Union,
  34. cast,
  35. )
  36. from urllib import parse
  37. from playwright._impl._api_structures import (
  38. ClientCertificate,
  39. Headers,
  40. HeadersArray,
  41. RemoteAddr,
  42. RequestSizes,
  43. ResourceTiming,
  44. SecurityDetails,
  45. )
  46. from playwright._impl._connection import (
  47. ChannelOwner,
  48. from_channel,
  49. from_nullable_channel,
  50. )
  51. from playwright._impl._errors import Error
  52. from playwright._impl._event_context_manager import EventContextManagerImpl
  53. from playwright._impl._helper import (
  54. URLMatch,
  55. WebSocketRouteHandlerCallback,
  56. async_readfile,
  57. locals_to_params,
  58. url_matches,
  59. )
  60. from playwright._impl._str_utils import escape_regex_flags
  61. from playwright._impl._waiter import Waiter
  62. if TYPE_CHECKING: # pragma: no cover
  63. from playwright._impl._browser_context import BrowserContext
  64. from playwright._impl._fetch import APIResponse
  65. from playwright._impl._frame import Frame
  66. from playwright._impl._page import Page, Worker
  67. class FallbackOverrideParameters(TypedDict, total=False):
  68. url: Optional[str]
  69. method: Optional[str]
  70. headers: Optional[Dict[str, str]]
  71. postData: Optional[Union[str, bytes]]
  72. class SerializedFallbackOverrides:
  73. def __init__(self) -> None:
  74. self.url: Optional[str] = None
  75. self.method: Optional[str] = None
  76. self.headers: Optional[Dict[str, str]] = None
  77. self.post_data_buffer: Optional[bytes] = None
  78. def serialize_headers(headers: Dict[str, str]) -> HeadersArray:
  79. return [
  80. {"name": name, "value": value}
  81. for name, value in headers.items()
  82. if value is not None
  83. ]
  84. async def to_client_certificates_protocol(
  85. clientCertificates: Optional[List[ClientCertificate]],
  86. ) -> Optional[List[Dict[str, str]]]:
  87. if not clientCertificates:
  88. return None
  89. out = []
  90. for clientCertificate in clientCertificates:
  91. out_record = {
  92. "origin": clientCertificate["origin"],
  93. }
  94. if passphrase := clientCertificate.get("passphrase"):
  95. out_record["passphrase"] = passphrase
  96. if pfx := clientCertificate.get("pfx"):
  97. out_record["pfx"] = base64.b64encode(pfx).decode()
  98. if pfx_path := clientCertificate.get("pfxPath"):
  99. out_record["pfx"] = base64.b64encode(
  100. await async_readfile(pfx_path)
  101. ).decode()
  102. if cert := clientCertificate.get("cert"):
  103. out_record["cert"] = base64.b64encode(cert).decode()
  104. if cert_path := clientCertificate.get("certPath"):
  105. out_record["cert"] = base64.b64encode(
  106. await async_readfile(cert_path)
  107. ).decode()
  108. if key := clientCertificate.get("key"):
  109. out_record["key"] = base64.b64encode(key).decode()
  110. if key_path := clientCertificate.get("keyPath"):
  111. out_record["key"] = base64.b64encode(
  112. await async_readfile(key_path)
  113. ).decode()
  114. out.append(out_record)
  115. return out
  116. class Request(ChannelOwner):
  117. def __init__(
  118. self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
  119. ) -> None:
  120. super().__init__(parent, type, guid, initializer)
  121. self._redirected_from: Optional["Request"] = from_nullable_channel(
  122. initializer.get("redirectedFrom")
  123. )
  124. self._redirected_to: Optional["Request"] = None
  125. if self._redirected_from:
  126. self._redirected_from._redirected_to = self
  127. self._failure_text: Optional[str] = None
  128. self._timing: ResourceTiming = {
  129. "startTime": 0,
  130. "domainLookupStart": -1,
  131. "domainLookupEnd": -1,
  132. "connectStart": -1,
  133. "secureConnectionStart": -1,
  134. "connectEnd": -1,
  135. "requestStart": -1,
  136. "responseStart": -1,
  137. "responseEnd": -1,
  138. }
  139. self._provisional_headers = RawHeaders(self._initializer["headers"])
  140. self._all_headers_future: Optional[asyncio.Future[RawHeaders]] = None
  141. self._fallback_overrides: SerializedFallbackOverrides = (
  142. SerializedFallbackOverrides()
  143. )
  144. def __repr__(self) -> str:
  145. return f"<Request url={self.url!r} method={self.method!r}>"
  146. def _apply_fallback_overrides(self, overrides: FallbackOverrideParameters) -> None:
  147. self._fallback_overrides.url = overrides.get(
  148. "url", self._fallback_overrides.url
  149. )
  150. self._fallback_overrides.method = overrides.get(
  151. "method", self._fallback_overrides.method
  152. )
  153. self._fallback_overrides.headers = overrides.get(
  154. "headers", self._fallback_overrides.headers
  155. )
  156. post_data = overrides.get("postData")
  157. if isinstance(post_data, str):
  158. self._fallback_overrides.post_data_buffer = post_data.encode()
  159. elif isinstance(post_data, bytes):
  160. self._fallback_overrides.post_data_buffer = post_data
  161. elif post_data is not None:
  162. self._fallback_overrides.post_data_buffer = json.dumps(post_data).encode()
  163. @property
  164. def url(self) -> str:
  165. return cast(str, self._fallback_overrides.url or self._initializer["url"])
  166. @property
  167. def resource_type(self) -> str:
  168. return self._initializer["resourceType"]
  169. @property
  170. def service_worker(self) -> Optional["Worker"]:
  171. return cast(
  172. Optional["Worker"],
  173. from_nullable_channel(self._initializer.get("serviceWorker")),
  174. )
  175. @property
  176. def method(self) -> str:
  177. return cast(str, self._fallback_overrides.method or self._initializer["method"])
  178. async def sizes(self) -> RequestSizes:
  179. response = await self.response()
  180. if not response:
  181. raise Error("Unable to fetch sizes for failed request")
  182. return await response._channel.send(
  183. "sizes",
  184. None,
  185. )
  186. @property
  187. def post_data(self) -> Optional[str]:
  188. data = self._fallback_overrides.post_data_buffer
  189. if data:
  190. return data.decode()
  191. base64_post_data = self._initializer.get("postData")
  192. if base64_post_data is not None:
  193. return base64.b64decode(base64_post_data).decode()
  194. return None
  195. @property
  196. def post_data_json(self) -> Optional[Any]:
  197. post_data = self.post_data
  198. if not post_data:
  199. return None
  200. content_type = self.headers["content-type"]
  201. if "application/x-www-form-urlencoded" in content_type:
  202. return dict(parse.parse_qsl(post_data))
  203. try:
  204. return json.loads(post_data)
  205. except Exception:
  206. raise Error(f"POST data is not a valid JSON object: {post_data}")
  207. @property
  208. def post_data_buffer(self) -> Optional[bytes]:
  209. if self._fallback_overrides.post_data_buffer:
  210. return self._fallback_overrides.post_data_buffer
  211. if self._initializer.get("postData"):
  212. return base64.b64decode(self._initializer["postData"])
  213. return None
  214. async def response(self) -> Optional["Response"]:
  215. return from_nullable_channel(
  216. await self._channel.send(
  217. "response",
  218. None,
  219. )
  220. )
  221. @property
  222. def frame(self) -> "Frame":
  223. if not self._initializer.get("frame"):
  224. raise Error("Service Worker requests do not have an associated frame.")
  225. frame = cast("Frame", from_channel(self._initializer["frame"]))
  226. if not frame._page:
  227. raise Error(
  228. "\n".join(
  229. [
  230. "Frame for this navigation request is not available, because the request",
  231. "was issued before the frame is created. You can check whether the request",
  232. "is a navigation request by calling isNavigationRequest() method.",
  233. ]
  234. )
  235. )
  236. return frame
  237. def is_navigation_request(self) -> bool:
  238. return self._initializer["isNavigationRequest"]
  239. @property
  240. def redirected_from(self) -> Optional["Request"]:
  241. return self._redirected_from
  242. @property
  243. def redirected_to(self) -> Optional["Request"]:
  244. return self._redirected_to
  245. @property
  246. def failure(self) -> Optional[str]:
  247. return self._failure_text
  248. @property
  249. def timing(self) -> ResourceTiming:
  250. return self._timing
  251. def _set_response_end_timing(self, response_end_timing: float) -> None:
  252. self._timing["responseEnd"] = response_end_timing
  253. if self._timing["responseStart"] == -1:
  254. self._timing["responseStart"] = response_end_timing
  255. @property
  256. def headers(self) -> Headers:
  257. override = self._fallback_overrides.headers
  258. if override:
  259. return RawHeaders._from_headers_dict_lossy(override).headers()
  260. return self._provisional_headers.headers()
  261. async def all_headers(self) -> Headers:
  262. return (await self._actual_headers()).headers()
  263. async def headers_array(self) -> HeadersArray:
  264. return (await self._actual_headers()).headers_array()
  265. async def header_value(self, name: str) -> Optional[str]:
  266. return (await self._actual_headers()).get(name)
  267. async def _actual_headers(self) -> "RawHeaders":
  268. override = self._fallback_overrides.headers
  269. if override:
  270. return RawHeaders(serialize_headers(override))
  271. if not self._all_headers_future:
  272. self._all_headers_future = asyncio.Future()
  273. headers = await self._channel.send(
  274. "rawRequestHeaders", None, is_internal=True
  275. )
  276. self._all_headers_future.set_result(RawHeaders(headers))
  277. return await self._all_headers_future
  278. def _target_closed_future(self) -> asyncio.Future:
  279. frame = cast(
  280. Optional["Frame"], from_nullable_channel(self._initializer.get("frame"))
  281. )
  282. if not frame:
  283. return asyncio.Future()
  284. page = frame._page
  285. if not page:
  286. return asyncio.Future()
  287. return page._closed_or_crashed_future
  288. def _safe_page(self) -> "Optional[Page]":
  289. frame = from_nullable_channel(self._initializer.get("frame"))
  290. if not frame:
  291. return None
  292. return cast("Frame", frame)._page
  293. class Route(ChannelOwner):
  294. def __init__(
  295. self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
  296. ) -> None:
  297. super().__init__(parent, type, guid, initializer)
  298. self._handling_future: Optional[asyncio.Future["bool"]] = None
  299. self._context: "BrowserContext" = cast("BrowserContext", None)
  300. self._did_throw = False
  301. def _start_handling(self) -> "asyncio.Future[bool]":
  302. self._handling_future = asyncio.Future()
  303. return self._handling_future
  304. def _report_handled(self, done: bool) -> None:
  305. chain = self._handling_future
  306. assert chain
  307. self._handling_future = None
  308. chain.set_result(done)
  309. def _check_not_handled(self) -> None:
  310. if not self._handling_future:
  311. raise Error("Route is already handled!")
  312. def __repr__(self) -> str:
  313. return f"<Route request={self.request}>"
  314. @property
  315. def request(self) -> Request:
  316. return from_channel(self._initializer["request"])
  317. async def abort(self, errorCode: str = None) -> None:
  318. await self._handle_route(
  319. lambda: self._race_with_page_close(
  320. self._channel.send(
  321. "abort",
  322. None,
  323. {
  324. "errorCode": errorCode,
  325. },
  326. )
  327. )
  328. )
  329. async def fulfill(
  330. self,
  331. status: int = None,
  332. headers: Dict[str, str] = None,
  333. body: Union[str, bytes] = None,
  334. json: Any = None,
  335. path: Union[str, Path] = None,
  336. contentType: str = None,
  337. response: "APIResponse" = None,
  338. ) -> None:
  339. await self._handle_route(
  340. lambda: self._inner_fulfill(
  341. status, headers, body, json, path, contentType, response
  342. )
  343. )
  344. async def _inner_fulfill(
  345. self,
  346. status: int = None,
  347. headers: Dict[str, str] = None,
  348. body: Union[str, bytes] = None,
  349. json: Any = None,
  350. path: Union[str, Path] = None,
  351. contentType: str = None,
  352. response: "APIResponse" = None,
  353. ) -> None:
  354. params = locals_to_params(locals())
  355. if json is not None:
  356. if body is not None:
  357. raise Error("Can specify either body or json parameters")
  358. body = json_utils.dumps(json)
  359. if response:
  360. del params["response"]
  361. params["status"] = (
  362. params["status"] if params.get("status") else response.status
  363. )
  364. params["headers"] = (
  365. params["headers"] if params.get("headers") else response.headers
  366. )
  367. from playwright._impl._fetch import APIResponse
  368. if body is None and path is None and isinstance(response, APIResponse):
  369. if response._request._connection is self._connection:
  370. params["fetchResponseUid"] = response._fetch_uid
  371. else:
  372. body = await response.body()
  373. length = 0
  374. if isinstance(body, str):
  375. params["body"] = body
  376. params["isBase64"] = False
  377. length = len(body.encode())
  378. elif isinstance(body, bytes):
  379. params["body"] = base64.b64encode(body).decode()
  380. params["isBase64"] = True
  381. length = len(body)
  382. elif path:
  383. del params["path"]
  384. file_content = Path(path).read_bytes()
  385. params["body"] = base64.b64encode(file_content).decode()
  386. params["isBase64"] = True
  387. length = len(file_content)
  388. headers = {k.lower(): str(v) for k, v in params.get("headers", {}).items()}
  389. if params.get("contentType"):
  390. headers["content-type"] = params["contentType"]
  391. elif json:
  392. headers["content-type"] = "application/json"
  393. elif path:
  394. headers["content-type"] = (
  395. mimetypes.guess_type(str(Path(path)))[0] or "application/octet-stream"
  396. )
  397. if length and "content-length" not in headers:
  398. headers["content-length"] = str(length)
  399. params["headers"] = serialize_headers(headers)
  400. await self._race_with_page_close(self._channel.send("fulfill", None, params))
  401. async def _handle_route(self, callback: Callable) -> None:
  402. self._check_not_handled()
  403. try:
  404. await callback()
  405. self._report_handled(True)
  406. except Exception as e:
  407. self._did_throw = True
  408. raise e
  409. async def fetch(
  410. self,
  411. url: str = None,
  412. method: str = None,
  413. headers: Dict[str, str] = None,
  414. postData: Union[Any, str, bytes] = None,
  415. maxRedirects: int = None,
  416. maxRetries: int = None,
  417. timeout: float = None,
  418. ) -> "APIResponse":
  419. return await self._connection.wrap_api_call(
  420. lambda: self._context.request._inner_fetch(
  421. self.request,
  422. url,
  423. method,
  424. headers,
  425. postData,
  426. maxRedirects=maxRedirects,
  427. maxRetries=maxRetries,
  428. timeout=timeout,
  429. )
  430. )
  431. async def fallback(
  432. self,
  433. url: str = None,
  434. method: str = None,
  435. headers: Dict[str, str] = None,
  436. postData: Union[Any, str, bytes] = None,
  437. ) -> None:
  438. overrides = cast(FallbackOverrideParameters, locals_to_params(locals()))
  439. self._check_not_handled()
  440. self.request._apply_fallback_overrides(overrides)
  441. self._report_handled(False)
  442. async def continue_(
  443. self,
  444. url: str = None,
  445. method: str = None,
  446. headers: Dict[str, str] = None,
  447. postData: Union[Any, str, bytes] = None,
  448. ) -> None:
  449. overrides = cast(FallbackOverrideParameters, locals_to_params(locals()))
  450. async def _inner() -> None:
  451. self.request._apply_fallback_overrides(overrides)
  452. await self._inner_continue(False)
  453. return await self._handle_route(_inner)
  454. async def _inner_continue(self, is_fallback: bool = False) -> None:
  455. options = self.request._fallback_overrides
  456. await self._race_with_page_close(
  457. self._channel.send(
  458. "continue",
  459. None,
  460. {
  461. "url": options.url,
  462. "method": options.method,
  463. "headers": (
  464. serialize_headers(options.headers) if options.headers else None
  465. ),
  466. "postData": (
  467. base64.b64encode(options.post_data_buffer).decode()
  468. if options.post_data_buffer is not None
  469. else None
  470. ),
  471. "isFallback": is_fallback,
  472. },
  473. )
  474. )
  475. async def _redirected_navigation_request(self, url: str) -> None:
  476. await self._handle_route(
  477. lambda: self._race_with_page_close(
  478. self._channel.send("redirectNavigationRequest", None, {"url": url})
  479. )
  480. )
  481. async def _race_with_page_close(self, future: Coroutine) -> None:
  482. fut = asyncio.create_task(future)
  483. # Rewrite the user's stack to the new task which runs in the background.
  484. setattr(
  485. fut,
  486. "__pw_stack__",
  487. getattr(asyncio.current_task(self._loop), "__pw_stack__", inspect.stack(0)),
  488. )
  489. target_closed_future = self.request._target_closed_future()
  490. await asyncio.wait(
  491. [fut, target_closed_future],
  492. return_when=asyncio.FIRST_COMPLETED,
  493. )
  494. if fut.done() and fut.exception():
  495. raise cast(BaseException, fut.exception())
  496. if target_closed_future.done():
  497. await asyncio.gather(fut, return_exceptions=True)
  498. def _create_task_and_ignore_exception(
  499. loop: asyncio.AbstractEventLoop, coro: Coroutine
  500. ) -> None:
  501. async def _ignore_exception() -> None:
  502. try:
  503. await coro
  504. except Exception:
  505. pass
  506. loop.create_task(_ignore_exception())
  507. class ServerWebSocketRoute:
  508. def __init__(self, ws: "WebSocketRoute"):
  509. self._ws = ws
  510. def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None:
  511. self._ws._on_server_message = handler
  512. def on_close(self, handler: Callable[[Optional[int], Optional[str]], Any]) -> None:
  513. self._ws._on_server_close = handler
  514. def connect_to_server(self) -> None:
  515. raise NotImplementedError(
  516. "connectToServer must be called on the page-side WebSocketRoute"
  517. )
  518. @property
  519. def url(self) -> str:
  520. return self._ws._initializer["url"]
  521. def close(self, code: int = None, reason: str = None) -> None:
  522. _create_task_and_ignore_exception(
  523. self._ws._loop,
  524. self._ws._channel.send(
  525. "closeServer",
  526. None,
  527. {
  528. "code": code,
  529. "reason": reason,
  530. "wasClean": True,
  531. },
  532. ),
  533. )
  534. def send(self, message: Union[str, bytes]) -> None:
  535. if isinstance(message, str):
  536. _create_task_and_ignore_exception(
  537. self._ws._loop,
  538. self._ws._channel.send(
  539. "sendToServer", None, {"message": message, "isBase64": False}
  540. ),
  541. )
  542. else:
  543. _create_task_and_ignore_exception(
  544. self._ws._loop,
  545. self._ws._channel.send(
  546. "sendToServer",
  547. None,
  548. {"message": base64.b64encode(message).decode(), "isBase64": True},
  549. ),
  550. )
  551. class WebSocketRoute(ChannelOwner):
  552. def __init__(
  553. self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
  554. ) -> None:
  555. super().__init__(parent, type, guid, initializer)
  556. self._on_page_message: Optional[Callable[[Union[str, bytes]], Any]] = None
  557. self._on_page_close: Optional[Callable[[Optional[int], Optional[str]], Any]] = (
  558. None
  559. )
  560. self._on_server_message: Optional[Callable[[Union[str, bytes]], Any]] = None
  561. self._on_server_close: Optional[
  562. Callable[[Optional[int], Optional[str]], Any]
  563. ] = None
  564. self._server = ServerWebSocketRoute(self)
  565. self._connected = False
  566. self._channel.on("messageFromPage", self._channel_message_from_page)
  567. self._channel.on("messageFromServer", self._channel_message_from_server)
  568. self._channel.on("closePage", self._channel_close_page)
  569. self._channel.on("closeServer", self._channel_close_server)
  570. def _channel_message_from_page(self, event: Dict) -> None:
  571. if self._on_page_message:
  572. self._on_page_message(
  573. base64.b64decode(event["message"])
  574. if event["isBase64"]
  575. else event["message"]
  576. )
  577. elif self._connected:
  578. _create_task_and_ignore_exception(
  579. self._loop, self._channel.send("sendToServer", None, event)
  580. )
  581. def _channel_message_from_server(self, event: Dict) -> None:
  582. if self._on_server_message:
  583. self._on_server_message(
  584. base64.b64decode(event["message"])
  585. if event["isBase64"]
  586. else event["message"]
  587. )
  588. else:
  589. _create_task_and_ignore_exception(
  590. self._loop, self._channel.send("sendToPage", None, event)
  591. )
  592. def _channel_close_page(self, event: Dict) -> None:
  593. if self._on_page_close:
  594. self._on_page_close(event["code"], event["reason"])
  595. else:
  596. _create_task_and_ignore_exception(
  597. self._loop, self._channel.send("closeServer", None, event)
  598. )
  599. def _channel_close_server(self, event: Dict) -> None:
  600. if self._on_server_close:
  601. self._on_server_close(event["code"], event["reason"])
  602. else:
  603. _create_task_and_ignore_exception(
  604. self._loop, self._channel.send("closePage", None, event)
  605. )
  606. @property
  607. def url(self) -> str:
  608. return self._initializer["url"]
  609. async def close(self, code: int = None, reason: str = None) -> None:
  610. try:
  611. await self._channel.send(
  612. "closePage", None, {"code": code, "reason": reason, "wasClean": True}
  613. )
  614. except Exception:
  615. pass
  616. def connect_to_server(self) -> "WebSocketRoute":
  617. if self._connected:
  618. raise Error("Already connected to the server")
  619. self._connected = True
  620. asyncio.create_task(
  621. self._channel.send(
  622. "connect",
  623. None,
  624. )
  625. )
  626. return cast("WebSocketRoute", self._server)
  627. def send(self, message: Union[str, bytes]) -> None:
  628. if isinstance(message, str):
  629. _create_task_and_ignore_exception(
  630. self._loop,
  631. self._channel.send(
  632. "sendToPage", None, {"message": message, "isBase64": False}
  633. ),
  634. )
  635. else:
  636. _create_task_and_ignore_exception(
  637. self._loop,
  638. self._channel.send(
  639. "sendToPage",
  640. None,
  641. {
  642. "message": base64.b64encode(message).decode(),
  643. "isBase64": True,
  644. },
  645. ),
  646. )
  647. def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None:
  648. self._on_page_message = handler
  649. def on_close(self, handler: Callable[[Optional[int], Optional[str]], Any]) -> None:
  650. self._on_page_close = handler
  651. async def _after_handle(self) -> None:
  652. if self._connected:
  653. return
  654. # Ensure that websocket is "open" and can send messages without an actual server connection.
  655. try:
  656. await self._channel.send(
  657. "ensureOpened",
  658. None,
  659. )
  660. except Exception:
  661. pass
  662. class WebSocketRouteHandler:
  663. def __init__(
  664. self,
  665. base_url: Optional[str],
  666. url: URLMatch,
  667. handler: WebSocketRouteHandlerCallback,
  668. ):
  669. self._base_url = base_url
  670. self.url = url
  671. self.handler = handler
  672. @staticmethod
  673. def prepare_interception_patterns(
  674. handlers: List["WebSocketRouteHandler"],
  675. ) -> List[dict]:
  676. patterns = []
  677. all_urls = False
  678. for handler in handlers:
  679. if isinstance(handler.url, str):
  680. patterns.append({"glob": handler.url})
  681. elif isinstance(handler.url, re.Pattern):
  682. patterns.append(
  683. {
  684. "regexSource": handler.url.pattern,
  685. "regexFlags": escape_regex_flags(handler.url),
  686. }
  687. )
  688. else:
  689. all_urls = True
  690. if all_urls:
  691. return [{"glob": "**/*"}]
  692. return patterns
  693. def matches(self, ws_url: str) -> bool:
  694. return url_matches(self._base_url, ws_url, self.url, True)
  695. async def handle(self, websocket_route: "WebSocketRoute") -> None:
  696. coro_or_future = self.handler(websocket_route)
  697. if asyncio.iscoroutine(coro_or_future):
  698. await coro_or_future
  699. await websocket_route._after_handle()
  700. class Response(ChannelOwner):
  701. def __init__(
  702. self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
  703. ) -> None:
  704. super().__init__(parent, type, guid, initializer)
  705. self._request: Request = from_channel(self._initializer["request"])
  706. timing = self._initializer["timing"]
  707. self._request._timing["startTime"] = timing["startTime"]
  708. self._request._timing["domainLookupStart"] = timing["domainLookupStart"]
  709. self._request._timing["domainLookupEnd"] = timing["domainLookupEnd"]
  710. self._request._timing["connectStart"] = timing["connectStart"]
  711. self._request._timing["secureConnectionStart"] = timing["secureConnectionStart"]
  712. self._request._timing["connectEnd"] = timing["connectEnd"]
  713. self._request._timing["requestStart"] = timing["requestStart"]
  714. self._request._timing["responseStart"] = timing["responseStart"]
  715. self._provisional_headers = RawHeaders(
  716. cast(HeadersArray, self._initializer["headers"])
  717. )
  718. self._raw_headers_future: Optional[asyncio.Future[RawHeaders]] = None
  719. self._finished_future: asyncio.Future[bool] = asyncio.Future()
  720. def __repr__(self) -> str:
  721. return f"<Response url={self.url!r} request={self.request}>"
  722. @property
  723. def url(self) -> str:
  724. return self._initializer["url"]
  725. @property
  726. def ok(self) -> bool:
  727. # Status 0 is for file:// URLs
  728. return self._initializer["status"] == 0 or (
  729. self._initializer["status"] >= 200 and self._initializer["status"] <= 299
  730. )
  731. @property
  732. def status(self) -> int:
  733. return self._initializer["status"]
  734. @property
  735. def status_text(self) -> str:
  736. return self._initializer["statusText"]
  737. @property
  738. def headers(self) -> Headers:
  739. return self._provisional_headers.headers()
  740. @property
  741. def from_service_worker(self) -> bool:
  742. return self._initializer["fromServiceWorker"]
  743. async def all_headers(self) -> Headers:
  744. return (await self._actual_headers()).headers()
  745. async def headers_array(self) -> HeadersArray:
  746. return (await self._actual_headers()).headers_array()
  747. async def header_value(self, name: str) -> Optional[str]:
  748. return (await self._actual_headers()).get(name)
  749. async def header_values(self, name: str) -> List[str]:
  750. return (await self._actual_headers()).get_all(name)
  751. async def _actual_headers(self) -> "RawHeaders":
  752. if not self._raw_headers_future:
  753. self._raw_headers_future = asyncio.Future()
  754. headers = cast(
  755. HeadersArray,
  756. await self._channel.send(
  757. "rawResponseHeaders",
  758. None,
  759. ),
  760. )
  761. self._raw_headers_future.set_result(RawHeaders(headers))
  762. return await self._raw_headers_future
  763. async def server_addr(self) -> Optional[RemoteAddr]:
  764. return await self._channel.send(
  765. "serverAddr",
  766. None,
  767. )
  768. async def security_details(self) -> Optional[SecurityDetails]:
  769. return await self._channel.send(
  770. "securityDetails",
  771. None,
  772. )
  773. async def finished(self) -> None:
  774. async def on_finished() -> None:
  775. await self._request._target_closed_future()
  776. raise Error("Target closed")
  777. on_finished_task = asyncio.create_task(on_finished())
  778. await asyncio.wait(
  779. cast(
  780. List[Union[asyncio.Task, asyncio.Future]],
  781. [self._finished_future, on_finished_task],
  782. ),
  783. return_when=asyncio.FIRST_COMPLETED,
  784. )
  785. if on_finished_task.done():
  786. await on_finished_task
  787. async def body(self) -> bytes:
  788. binary = await self._channel.send(
  789. "body",
  790. None,
  791. )
  792. return base64.b64decode(binary)
  793. async def text(self) -> str:
  794. content = await self.body()
  795. return content.decode()
  796. async def json(self) -> Any:
  797. return json.loads(await self.text())
  798. @property
  799. def request(self) -> Request:
  800. return self._request
  801. @property
  802. def frame(self) -> "Frame":
  803. return self._request.frame
  804. class WebSocket(ChannelOwner):
  805. Events = SimpleNamespace(
  806. Close="close",
  807. FrameReceived="framereceived",
  808. FrameSent="framesent",
  809. Error="socketerror",
  810. )
  811. def __init__(
  812. self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
  813. ) -> None:
  814. super().__init__(parent, type, guid, initializer)
  815. self._is_closed = False
  816. self._page = cast("Page", parent)
  817. self._channel.on(
  818. "frameSent",
  819. lambda params: self._on_frame_sent(params["opcode"], params["data"]),
  820. )
  821. self._channel.on(
  822. "frameReceived",
  823. lambda params: self._on_frame_received(params["opcode"], params["data"]),
  824. )
  825. self._channel.on(
  826. "socketError",
  827. lambda params: self.emit(WebSocket.Events.Error, params["error"]),
  828. )
  829. self._channel.on("close", lambda params: self._on_close())
  830. def __repr__(self) -> str:
  831. return f"<WebSocket url={self.url!r}>"
  832. @property
  833. def url(self) -> str:
  834. return self._initializer["url"]
  835. def expect_event(
  836. self,
  837. event: str,
  838. predicate: Callable = None,
  839. timeout: float = None,
  840. ) -> EventContextManagerImpl:
  841. if timeout is None:
  842. timeout = cast(Any, self._parent)._timeout_settings.timeout()
  843. waiter = Waiter(self, f"web_socket.expect_event({event})")
  844. waiter.reject_on_timeout(
  845. cast(float, timeout),
  846. f'Timeout {timeout}ms exceeded while waiting for event "{event}"',
  847. )
  848. if event != WebSocket.Events.Close:
  849. waiter.reject_on_event(self, WebSocket.Events.Close, Error("Socket closed"))
  850. if event != WebSocket.Events.Error:
  851. waiter.reject_on_event(self, WebSocket.Events.Error, Error("Socket error"))
  852. waiter.reject_on_event(
  853. self._page, "close", lambda: self._page._close_error_with_reason()
  854. )
  855. waiter.wait_for_event(self, event, predicate)
  856. return EventContextManagerImpl(waiter.result())
  857. async def wait_for_event(
  858. self, event: str, predicate: Callable = None, timeout: float = None
  859. ) -> Any:
  860. async with self.expect_event(event, predicate, timeout) as event_info:
  861. pass
  862. return await event_info
  863. def _on_frame_sent(self, opcode: int, data: str) -> None:
  864. if opcode == 2:
  865. self.emit(WebSocket.Events.FrameSent, base64.b64decode(data))
  866. elif opcode == 1:
  867. self.emit(WebSocket.Events.FrameSent, data)
  868. def _on_frame_received(self, opcode: int, data: str) -> None:
  869. if opcode == 2:
  870. self.emit(WebSocket.Events.FrameReceived, base64.b64decode(data))
  871. elif opcode == 1:
  872. self.emit(WebSocket.Events.FrameReceived, data)
  873. def is_closed(self) -> bool:
  874. return self._is_closed
  875. def _on_close(self) -> None:
  876. self._is_closed = True
  877. self.emit(WebSocket.Events.Close, self)
  878. class RawHeaders:
  879. def __init__(self, headers: HeadersArray) -> None:
  880. self._headers_array = headers
  881. self._headers_map: Dict[str, Dict[str, bool]] = defaultdict(dict)
  882. for header in headers:
  883. self._headers_map[header["name"].lower()][header["value"]] = True
  884. @staticmethod
  885. def _from_headers_dict_lossy(headers: Dict[str, str]) -> "RawHeaders":
  886. return RawHeaders(serialize_headers(headers))
  887. def get(self, name: str) -> Optional[str]:
  888. values = self.get_all(name)
  889. if not values:
  890. return None
  891. separator = "\n" if name.lower() == "set-cookie" else ", "
  892. return separator.join(values)
  893. def get_all(self, name: str) -> List[str]:
  894. return list(self._headers_map[name.lower()].keys())
  895. def headers(self) -> Dict[str, str]:
  896. result = {}
  897. for name in self._headers_map.keys():
  898. result[name] = cast(str, self.get(name))
  899. return result
  900. def headers_array(self) -> HeadersArray:
  901. return self._headers_array