_trio.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343
  1. from __future__ import annotations
  2. import array
  3. import math
  4. import os
  5. import socket
  6. import sys
  7. import types
  8. import weakref
  9. from collections.abc import (
  10. AsyncGenerator,
  11. AsyncIterator,
  12. Awaitable,
  13. Callable,
  14. Collection,
  15. Coroutine,
  16. Iterable,
  17. Sequence,
  18. )
  19. from contextlib import AbstractContextManager
  20. from dataclasses import dataclass
  21. from io import IOBase
  22. from os import PathLike
  23. from signal import Signals
  24. from socket import AddressFamily, SocketKind
  25. from types import TracebackType
  26. from typing import (
  27. IO,
  28. TYPE_CHECKING,
  29. Any,
  30. Generic,
  31. NoReturn,
  32. ParamSpec,
  33. TypeVar,
  34. cast,
  35. overload,
  36. )
  37. import trio.from_thread
  38. import trio.lowlevel
  39. from outcome import Error, Outcome, Value
  40. from trio.lowlevel import (
  41. current_root_task,
  42. current_task,
  43. notify_closing,
  44. wait_readable,
  45. wait_writable,
  46. )
  47. from trio.socket import SocketType as TrioSocketType
  48. from trio.to_thread import run_sync
  49. from .. import (
  50. CapacityLimiterStatistics,
  51. EventStatistics,
  52. LockStatistics,
  53. RunFinishedError,
  54. TaskInfo,
  55. WouldBlock,
  56. abc,
  57. )
  58. from .._core._eventloop import claim_worker_thread
  59. from .._core._exceptions import (
  60. BrokenResourceError,
  61. BusyResourceError,
  62. ClosedResourceError,
  63. EndOfStream,
  64. )
  65. from .._core._sockets import convert_ipv6_sockaddr
  66. from .._core._streams import create_memory_object_stream
  67. from .._core._synchronization import (
  68. CapacityLimiter as BaseCapacityLimiter,
  69. )
  70. from .._core._synchronization import Event as BaseEvent
  71. from .._core._synchronization import Lock as BaseLock
  72. from .._core._synchronization import (
  73. ResourceGuard,
  74. SemaphoreStatistics,
  75. )
  76. from .._core._synchronization import Semaphore as BaseSemaphore
  77. from .._core._tasks import CancelScope as BaseCancelScope
  78. from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
  79. from ..abc._eventloop import AsyncBackend, StrOrBytesPath
  80. from ..streams.memory import MemoryObjectSendStream
  81. if TYPE_CHECKING:
  82. from _typeshed import FileDescriptorLike
  83. if sys.version_info >= (3, 11):
  84. from typing import TypeVarTuple, Unpack
  85. else:
  86. from exceptiongroup import BaseExceptionGroup
  87. from typing_extensions import TypeVarTuple, Unpack
  # Generic type parameters used throughout this backend module.
  T = TypeVar("T")
  T_Retval = TypeVar("T_Retval")  # return type of a wrapped callable
  P = ParamSpec("P")  # full parameter specification of a wrapped callable
  PosArgsT = TypeVarTuple("PosArgsT")  # positional-argument types of a callable
  # Socket address type: ``str`` (used by the UNIX datagram classes) or an IP address.
  T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
  93. #
  94. # Event loop
  95. #
  # Short module-level alias for trio's run-scoped variable type; the run
  # variables later in this module are created through it.
  RunVar = trio.lowlevel.RunVar
  97. #
  98. # Timeouts and cancellation
  99. #
  class CancelScope(BaseCancelScope):
      """AnyIO cancel scope backed by a :class:`trio.CancelScope`.

      Wraps the trio scope passed as ``original`` or, when none is given, creates a
      fresh one from the remaining keyword arguments. Every operation and property
      is forwarded to that trio scope.
      """

      def __new__(
          cls, original: trio.CancelScope | None = None, **kwargs: object
      ) -> CancelScope:
          # Allocate directly with object.__new__, skipping the base class __new__.
          return object.__new__(cls)

      def __init__(self, original: trio.CancelScope | None = None, **kwargs: Any) -> None:
          if original:
              self.__original = original
          else:
              self.__original = trio.CancelScope(**kwargs)

      def __enter__(self) -> CancelScope:
          self.__original.__enter__()
          return self

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> bool:
          # The trio scope decides whether the exception is suppressed; pass its
          # answer straight through.
          return self.__original.__exit__(exc_type, exc_val, exc_tb)

      def cancel(self, reason: str | None = None) -> None:
          """Cancel the wrapped trio scope, forwarding the optional *reason*."""
          self.__original.cancel(reason)

      @property
      def cancel_called(self) -> bool:
          return self.__original.cancel_called

      @property
      def cancelled_caught(self) -> bool:
          return self.__original.cancelled_caught

      @property
      def deadline(self) -> float:
          return self.__original.deadline

      @deadline.setter
      def deadline(self, value: float) -> None:
          self.__original.deadline = value

      @property
      def shield(self) -> bool:
          return self.__original.shield

      @shield.setter
      def shield(self, value: bool) -> None:
          self.__original.shield = value
  137. #
  138. # Task groups
  139. #
  class TaskGroup(abc.TaskGroup):
      """AnyIO task group implemented on a trio nursery with strict exception groups."""

      def __init__(self) -> None:
          self._active = False
          self._nursery_manager = trio.open_nursery(strict_exception_groups=True)
          # Replaced by a CancelScope wrapping the nursery's scope in __aenter__.
          self.cancel_scope = None  # type: ignore[assignment]

      def _ensure_active(self) -> None:
          """Raise RuntimeError unless the group is currently entered."""
          if self._active:
              return
          raise RuntimeError(
              "This task group is not active; no new tasks can be started."
          )

      async def __aenter__(self) -> TaskGroup:
          self._active = True
          self._nursery = await self._nursery_manager.__aenter__()
          self.cancel_scope = CancelScope(self._nursery.cancel_scope)
          return self

      async def __aexit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> bool:
          try:
              # trio.Nursery.__exit__ returns bool; .open_nursery has wrong type
              return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)  # type: ignore[return-value]
          except BaseExceptionGroup as exc:
              _, non_cancelled = exc.split(trio.Cancelled)
              if non_cancelled:
                  raise
              # The group holds nothing but trio.Cancelled: re-raise it as a single
              # Cancelled chained to the original group.
              raise trio.Cancelled._create() from exc
          finally:
              # Release the exception/traceback references held by this frame.
              del exc_val, exc_tb
              self._active = False

      def start_soon(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
          *args: Unpack[PosArgsT],
          name: object = None,
      ) -> None:
          """Schedule ``func(*args)`` in the nursery without waiting for it to start.

          :raises RuntimeError: if the group is not currently entered
          """
          self._ensure_active()
          self._nursery.start_soon(func, *args, name=name)

      async def start(
          self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
      ) -> Any:
          """Start ``func(*args)`` in the nursery and return the value it reports.

          :raises RuntimeError: if the group is not currently entered
          """
          self._ensure_active()
          return await self._nursery.start(func, *args, name=name)
  185. #
  186. # Subprocesses
  187. #
  @dataclass(eq=False)
  class ReceiveStreamWrapper(abc.ByteReceiveStream):
      """Adapt a trio receive stream to AnyIO's byte receive stream interface."""

      _stream: trio.abc.ReceiveStream

      async def receive(self, max_bytes: int | None = None) -> bytes:
          """Receive a chunk of at most *max_bytes* bytes.

          :raises ClosedResourceError: if trio reports the stream as closed
          :raises BrokenResourceError: if trio reports the stream as broken
          :raises EndOfStream: if the underlying stream returned an empty chunk
          """
          try:
              chunk = await self._stream.receive_some(max_bytes)
          except trio.ClosedResourceError as exc:
              raise ClosedResourceError from exc.__cause__
          except trio.BrokenResourceError as exc:
              raise BrokenResourceError from exc.__cause__

          if not chunk:
              raise EndOfStream
          return bytes(chunk)

      async def aclose(self) -> None:
          await self._stream.aclose()
  @dataclass(eq=False)
  class SendStreamWrapper(abc.ByteSendStream):
      """Adapt a trio send stream to AnyIO's byte send stream interface."""

      _stream: trio.abc.SendStream

      async def send(self, item: bytes) -> None:
          """Send all of *item*, translating trio's resource errors to AnyIO's."""
          try:
              await self._stream.send_all(item)
          except trio.ClosedResourceError as exc:
              raise ClosedResourceError from exc.__cause__
          except trio.BrokenResourceError as exc:
              raise BrokenResourceError from exc.__cause__

      async def aclose(self) -> None:
          await self._stream.aclose()
  @dataclass(eq=False)
  class Process(abc.Process):
      """AnyIO process handle wrapping a :class:`trio.Process` and its pipe streams."""

      _process: trio.Process
      _stdin: abc.ByteSendStream | None
      _stdout: abc.ByteReceiveStream | None
      _stderr: abc.ByteReceiveStream | None

      async def aclose(self) -> None:
          """Close every pipe, then wait for the process to exit.

          If waiting raises anything (including cancellation), the process is killed
          and waited for again inside a shielded scope before the exception is
          re-raised.
          """
          with CancelScope(shield=True):
              for pipe in (self._stdin, self._stdout, self._stderr):
                  if pipe:
                      await pipe.aclose()

          try:
              await self.wait()
          except BaseException:
              self.kill()
              with CancelScope(shield=True):
                  await self.wait()
              raise

      async def wait(self) -> int:
          return await self._process.wait()

      def terminate(self) -> None:
          self._process.terminate()

      def kill(self) -> None:
          self._process.kill()

      def send_signal(self, signal: Signals) -> None:
          self._process.send_signal(signal)

      @property
      def pid(self) -> int:
          return self._process.pid

      @property
      def returncode(self) -> int | None:
          return self._process.returncode

      @property
      def stdin(self) -> abc.ByteSendStream | None:
          return self._stdin

      @property
      def stdout(self) -> abc.ByteReceiveStream | None:
          return self._stdout

      @property
      def stderr(self) -> abc.ByteReceiveStream | None:
          return self._stderr
  class _ProcessPoolShutdownInstrument(trio.abc.Instrument):
      """trio instrument whose ``after_run`` hook only defers to the base class."""

      def after_run(self) -> None:
          # NOTE(review): adds no behaviour of its own yet; presumably a hook point
          # for process-pool cleanup - confirm against its users.
          super().after_run()
  # Run-scoped variable; presumably holds the default worker-process capacity
  # limiter for the current trio run (verify against its users).
  current_default_worker_process_limiter: trio.lowlevel.RunVar = RunVar(
      "current_default_worker_process_limiter"
  )
  async def _shutdown_process_pool(workers: set[abc.Process]) -> None:
      """Sleep until cancelled, then kill any live workers and close them all.

      The cancellation is absorbed: after cleanup the coroutine returns normally.

      :param workers: the worker processes to shut down
      """
      try:
          await trio.sleep(math.inf)
      except trio.Cancelled:
          for worker in workers:
              if worker.returncode is None:
                  worker.kill()

          # Closing awaits the processes, so shield it from the cancellation that
          # brought us here.
          with CancelScope(shield=True):
              for worker in workers:
                  await worker.aclose()
  276. #
  277. # Sockets and networking
  278. #
  class _TrioSocketMixin(Generic[T_SockAddr]):
      """Shared state and error translation for the trio-backed socket classes."""

      def __init__(self, trio_socket: TrioSocketType) -> None:
          self._trio_socket = trio_socket
          # Set by aclose(); distinguishes "closed by us" from "broken".
          self._closed = False

      def _check_closed(self) -> None:
          """Raise if the socket can no longer be used.

          :raises ClosedResourceError: if this object has been closed via aclose()
          :raises BrokenResourceError: if the file descriptor is otherwise gone
          """
          if self._closed:
              raise ClosedResourceError
          if self._trio_socket.fileno() < 0:
              raise BrokenResourceError

      @property
      def _raw_socket(self) -> socket.socket:
          return self._trio_socket._sock  # type: ignore[attr-defined]

      async def aclose(self) -> None:
          """Close the socket unless its file descriptor is already gone."""
          if self._trio_socket.fileno() < 0:
              return
          self._closed = True
          self._trio_socket.close()

      def _convert_socket_error(self, exc: BaseException) -> NoReturn:
          """Re-raise *exc* translated into AnyIO's exception types; never returns."""
          if isinstance(exc, trio.ClosedResourceError):
              raise ClosedResourceError from exc
          if self._trio_socket.fileno() < 0 and self._closed:
              # The descriptor is gone because aclose() closed it.
              raise ClosedResourceError from None
          if isinstance(exc, OSError):
              raise BrokenResourceError from exc
          # Anything else is propagated unchanged.
          raise exc
  class SocketStream(_TrioSocketMixin, abc.SocketStream):
      """Connected byte stream over a trio socket."""

      def __init__(self, trio_socket: TrioSocketType) -> None:
          super().__init__(trio_socket)
          # One guard per direction; presumably they reject concurrent use of the
          # same direction (ResourceGuard is defined elsewhere).
          self._receive_guard = ResourceGuard("reading from")
          self._send_guard = ResourceGuard("writing to")

      async def receive(self, max_bytes: int = 65536) -> bytes:
          """Receive up to *max_bytes* bytes.

          :raises EndOfStream: if the socket returned no data
          """
          with self._receive_guard:
              try:
                  chunk = await self._trio_socket.recv(max_bytes)
              except BaseException as exc:
                  self._convert_socket_error(exc)

              if not chunk:
                  raise EndOfStream
              return chunk

      async def send(self, item: bytes) -> None:
          """Send every byte of *item*, issuing as many socket sends as needed."""
          with self._send_guard:
              pending = memoryview(item)
              while pending:
                  try:
                      sent = await self._trio_socket.send(pending)
                  except BaseException as exc:
                      self._convert_socket_error(exc)
                  # A send may be partial; continue from where it stopped.
                  pending = pending[sent:]

      async def send_eof(self) -> None:
          """Shut down the write half of the connection."""
          self._trio_socket.shutdown(socket.SHUT_WR)
  class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
      """UNIX stream socket that can also pass file descriptors (SCM_RIGHTS)."""

      async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
          """Receive a message of up to *msglen* bytes plus up to *maxfds* descriptors.

          :raises ValueError: if *msglen* is negative or *maxfds* is less than 1
          :raises EndOfStream: if neither data nor ancillary data arrived
          :raises RuntimeError: if ancillary data other than SCM_RIGHTS arrived
          """
          if not isinstance(msglen, int) or msglen < 0:
              raise ValueError("msglen must be a non-negative integer")
          if not isinstance(maxfds, int) or maxfds < 1:
              raise ValueError("maxfds must be a positive integer")

          received = array.array("i")
          itemsize = received.itemsize
          await trio.lowlevel.checkpoint()
          with self._receive_guard:
              try:
                  message, ancdata, _flags, _addr = await self._trio_socket.recvmsg(
                      msglen, socket.CMSG_LEN(maxfds * itemsize)
                  )
              except BaseException as exc:
                  # _convert_socket_error() never returns, so no retry loop is needed.
                  self._convert_socket_error(exc)

              if not message and not ancdata:
                  raise EndOfStream

          for cmsg_level, cmsg_type, cmsg_data in ancdata:
              if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                  raise RuntimeError(
                      f"Received unexpected ancillary data; message = {message!r}, "
                      f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                  )
              # Ignore any trailing bytes that do not form a whole descriptor.
              whole = len(cmsg_data) - len(cmsg_data) % itemsize
              received.frombytes(cmsg_data[:whole])

          return message, list(received)

      async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
          """Send *message* together with the descriptors in *fds*.

          Integers are sent as-is and IOBase objects via ``fileno()``; any other
          entries are skipped.

          :raises ValueError: if *message* or *fds* is empty
          """
          if not message:
              raise ValueError("message must not be empty")
          if not fds:
              raise ValueError("fds must not be empty")

          filenos: list[int] = []
          for entry in fds:
              if isinstance(entry, int):
                  filenos.append(entry)
              elif isinstance(entry, IOBase):
                  filenos.append(entry.fileno())

          fd_array = array.array("i", filenos)
          await trio.lowlevel.checkpoint()
          with self._send_guard:
              try:
                  await self._trio_socket.sendmsg(
                      [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fd_array)]
                  )
              except BaseException as exc:
                  # _convert_socket_error() never returns, so no retry loop is needed.
                  self._convert_socket_error(exc)
  class TCPSocketListener(_TrioSocketMixin, abc.SocketListener):
      """Listener accepting TCP connections from a bound stdlib socket."""

      def __init__(self, raw_socket: socket.socket):
          super().__init__(trio.socket.from_stdlib_socket(raw_socket))
          self._accept_guard = ResourceGuard("accepting connections from")

      async def accept(self) -> SocketStream:
          """Accept one connection, enabling TCP_NODELAY on it."""
          with self._accept_guard:
              try:
                  conn, _peer = await self._trio_socket.accept()
              except BaseException as exc:
                  self._convert_socket_error(exc)

          conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
          return SocketStream(conn)
  class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener):
      """Listener accepting UNIX stream connections from a bound stdlib socket."""

      def __init__(self, raw_socket: socket.socket):
          super().__init__(trio.socket.from_stdlib_socket(raw_socket))
          self._accept_guard = ResourceGuard("accepting connections from")

      async def accept(self) -> UNIXSocketStream:
          """Accept one connection and wrap it as a UNIX socket stream."""
          with self._accept_guard:
              try:
                  conn, _peer = await self._trio_socket.accept()
              except BaseException as exc:
                  self._convert_socket_error(exc)

          return UNIXSocketStream(conn)
  class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket):
      """Unconnected UDP socket backed by trio."""

      def __init__(self, trio_socket: TrioSocketType) -> None:
          super().__init__(trio_socket)
          self._receive_guard = ResourceGuard("reading from")
          self._send_guard = ResourceGuard("writing to")

      async def receive(self) -> tuple[bytes, IPSockAddrType]:
          """Receive one datagram (buffer of 65536 bytes) and its sender's address."""
          with self._receive_guard:
              try:
                  payload, sender = await self._trio_socket.recvfrom(65536)
                  # Address conversion stays inside the try so its errors are
                  # translated as well.
                  return payload, convert_ipv6_sockaddr(sender)
              except BaseException as exc:
                  self._convert_socket_error(exc)

      async def send(self, item: UDPPacketType) -> None:
          """Send *item*, unpacked as the arguments of ``sendto``."""
          with self._send_guard:
              try:
                  await self._trio_socket.sendto(*item)
              except BaseException as exc:
                  self._convert_socket_error(exc)
  class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket):
      """UDP socket connected to a single peer, backed by trio."""

      def __init__(self, trio_socket: TrioSocketType) -> None:
          super().__init__(trio_socket)
          self._receive_guard = ResourceGuard("reading from")
          self._send_guard = ResourceGuard("writing to")

      async def receive(self) -> bytes:
          """Receive one datagram using a 65536-byte buffer."""
          with self._receive_guard:
              try:
                  datagram = await self._trio_socket.recv(65536)
                  return datagram
              except BaseException as exc:
                  self._convert_socket_error(exc)

      async def send(self, item: bytes) -> None:
          """Send *item* as one datagram to the connected peer."""
          with self._send_guard:
              try:
                  await self._trio_socket.send(item)
              except BaseException as exc:
                  self._convert_socket_error(exc)
  class UNIXDatagramSocket(_TrioSocketMixin[str], abc.UNIXDatagramSocket):
      """Unconnected UNIX datagram socket backed by trio."""

      def __init__(self, trio_socket: TrioSocketType) -> None:
          super().__init__(trio_socket)
          self._receive_guard = ResourceGuard("reading from")
          self._send_guard = ResourceGuard("writing to")

      async def receive(self) -> UNIXDatagramPacketType:
          """Receive one datagram (buffer of 65536 bytes) and its sender's address."""
          with self._receive_guard:
              try:
                  payload, sender = await self._trio_socket.recvfrom(65536)
                  return payload, sender
              except BaseException as exc:
                  self._convert_socket_error(exc)

      async def send(self, item: UNIXDatagramPacketType) -> None:
          """Send *item*, unpacked as the arguments of ``sendto``."""
          with self._send_guard:
              try:
                  await self._trio_socket.sendto(*item)
              except BaseException as exc:
                  self._convert_socket_error(exc)
  class ConnectedUNIXDatagramSocket(
      _TrioSocketMixin[str], abc.ConnectedUNIXDatagramSocket
  ):
      """UNIX datagram socket connected to a single peer, backed by trio."""

      def __init__(self, trio_socket: TrioSocketType) -> None:
          super().__init__(trio_socket)
          self._receive_guard = ResourceGuard("reading from")
          self._send_guard = ResourceGuard("writing to")

      async def receive(self) -> bytes:
          """Receive one datagram using a 65536-byte buffer."""
          with self._receive_guard:
              try:
                  datagram = await self._trio_socket.recv(65536)
                  return datagram
              except BaseException as exc:
                  self._convert_socket_error(exc)

      async def send(self, item: bytes) -> None:
          """Send *item* as one datagram to the connected peer."""
          with self._send_guard:
              try:
                  await self._trio_socket.send(item)
              except BaseException as exc:
                  self._convert_socket_error(exc)
  482. #
  483. # Synchronization
  484. #
  class Event(BaseEvent):
      """AnyIO event that forwards every operation to a :class:`trio.Event`."""

      def __new__(cls) -> Event:
          # Allocate directly with object.__new__, skipping the base class __new__.
          return object.__new__(cls)

      def __init__(self) -> None:
          self.__original = trio.Event()

      def set(self) -> None:
          self.__original.set()

      def is_set(self) -> bool:
          return self.__original.is_set()

      async def wait(self) -> None:
          return await self.__original.wait()

      def statistics(self) -> EventStatistics:
          waiting = self.__original.statistics().tasks_waiting
          return EventStatistics(tasks_waiting=waiting)
  class Lock(BaseLock):
      """AnyIO lock backed by :class:`trio.Lock`, with an optional fast acquire path."""

      def __new__(cls, *, fast_acquire: bool = False) -> Lock:
          return object.__new__(cls)

      def __init__(self, *, fast_acquire: bool = False) -> None:
          self._fast_acquire = fast_acquire
          self.__original = trio.Lock()

      @staticmethod
      def _convert_runtime_error_msg(exc: RuntimeError) -> None:
          """Rewrite trio's re-acquire error message into AnyIO's wording, in place."""
          trio_args = ("attempt to re-acquire an already held Lock",)
          if exc.args == trio_args:
              exc.args = ("Attempted to acquire an already held Lock",)

      async def acquire(self) -> None:
          if self._fast_acquire:
              # This is the "fast path" where we don't let other tasks run:
              # only cancellation is checked unless the lock is contended.
              await trio.lowlevel.checkpoint_if_cancelled()
              try:
                  self.__original.acquire_nowait()
              except trio.WouldBlock:
                  await self.__original._lot.park()
              except RuntimeError as exc:
                  self._convert_runtime_error_msg(exc)
                  raise
              return

          try:
              await self.__original.acquire()
          except RuntimeError as exc:
              self._convert_runtime_error_msg(exc)
              raise

      def acquire_nowait(self) -> None:
          """Acquire without waiting.

          :raises WouldBlock: if the lock is held by another task
          """
          try:
              self.__original.acquire_nowait()
          except trio.WouldBlock:
              raise WouldBlock from None
          except RuntimeError as exc:
              self._convert_runtime_error_msg(exc)
              raise

      def locked(self) -> bool:
          return self.__original.locked()

      def release(self) -> None:
          self.__original.release()

      def statistics(self) -> LockStatistics:
          stats = self.__original.statistics()
          holder = None
          if stats.owner:
              holder = TrioTaskInfo(stats.owner)
          return LockStatistics(stats.locked, holder, stats.tasks_waiting)
  class Semaphore(BaseSemaphore):
      """AnyIO semaphore backed by :class:`trio.Semaphore`, with an optional fast path."""

      def __new__(
          cls,
          initial_value: int,
          *,
          max_value: int | None = None,
          fast_acquire: bool = False,
      ) -> Semaphore:
          return object.__new__(cls)

      def __init__(
          self,
          initial_value: int,
          *,
          max_value: int | None = None,
          fast_acquire: bool = False,
      ) -> None:
          super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
          self.__original = trio.Semaphore(initial_value, max_value=max_value)

      async def acquire(self) -> None:
          if self._fast_acquire:
              # This is the "fast path" where we don't let other tasks run:
              # only cancellation is checked unless no token is free.
              await trio.lowlevel.checkpoint_if_cancelled()
              try:
                  self.__original.acquire_nowait()
              except trio.WouldBlock:
                  await self.__original._lot.park()
              return

          await self.__original.acquire()

      def acquire_nowait(self) -> None:
          """Acquire without waiting.

          :raises WouldBlock: if the semaphore value is zero
          """
          try:
              self.__original.acquire_nowait()
          except trio.WouldBlock:
              raise WouldBlock from None

      @property
      def max_value(self) -> int | None:
          return self.__original.max_value

      @property
      def value(self) -> int:
          return self.__original.value

      def release(self) -> None:
          self.__original.release()

      def statistics(self) -> SemaphoreStatistics:
          waiting = self.__original.statistics().tasks_waiting
          return SemaphoreStatistics(waiting)
  class CapacityLimiter(BaseCapacityLimiter):
      """AnyIO capacity limiter backed by :class:`trio.CapacityLimiter`.

      Either wraps an existing trio limiter (``original``) or creates one with
      ``total_tokens`` tokens.
      """

      def __new__(
          cls,
          total_tokens: float | None = None,
          *,
          original: trio.CapacityLimiter | None = None,
      ) -> CapacityLimiter:
          return object.__new__(cls)

      def __init__(
          self,
          total_tokens: float | None = None,
          *,
          original: trio.CapacityLimiter | None = None,
      ) -> None:
          if original is not None:
              self.__original = original
          else:
              assert total_tokens is not None
              self.__original = trio.CapacityLimiter(total_tokens)

      async def __aenter__(self) -> None:
          return await self.__original.__aenter__()

      async def __aexit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> None:
          await self.__original.__aexit__(exc_type, exc_val, exc_tb)

      @property
      def total_tokens(self) -> float:
          return self.__original.total_tokens

      @total_tokens.setter
      def total_tokens(self, value: float) -> None:
          self.__original.total_tokens = value

      @property
      def borrowed_tokens(self) -> int:
          return self.__original.borrowed_tokens

      @property
      def available_tokens(self) -> float:
          return self.__original.available_tokens

      def acquire_nowait(self) -> None:
          """Borrow a token for the current task without waiting.

          :raises WouldBlock: if no tokens are available. trio's own WouldBlock is
              translated, consistent with Lock/Semaphore in this module.
          """
          try:
              self.__original.acquire_nowait()
          except trio.WouldBlock:
              raise WouldBlock from None

      def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
          """Borrow a token for *borrower* without waiting.

          :raises WouldBlock: if no tokens are available
          """
          try:
              self.__original.acquire_on_behalf_of_nowait(borrower)
          except trio.WouldBlock:
              raise WouldBlock from None

      async def acquire(self) -> None:
          await self.__original.acquire()

      async def acquire_on_behalf_of(self, borrower: object) -> None:
          await self.__original.acquire_on_behalf_of(borrower)

      def release(self) -> None:
          return self.__original.release()

      def release_on_behalf_of(self, borrower: object) -> None:
          return self.__original.release_on_behalf_of(borrower)

      def statistics(self) -> CapacityLimiterStatistics:
          orig = self.__original.statistics()
          return CapacityLimiterStatistics(
              borrowed_tokens=orig.borrowed_tokens,
              total_tokens=orig.total_tokens,
              borrowers=tuple(orig.borrowers),
              tasks_waiting=orig.tasks_waiting,
          )
  # Run-scoped variable; presumably caches a CapacityLimiter wrapper per trio run
  # (verify against its users).
  _capacity_limiter_wrapper: trio.lowlevel.RunVar = RunVar(
      "_capacity_limiter_wrapper"
  )
  649. #
  650. # Signal handling
  651. #
  class _SignalReceiver:
      """Context manager and async iterator yielding received signals via trio."""

      _iterator: AsyncIterator[int]

      def __init__(self, signals: tuple[Signals, ...]):
          self._signals = signals

      def __enter__(self) -> _SignalReceiver:
          # The trio receiver is opened on entry, not at construction time.
          self._cm = trio.open_signal_receiver(*self._signals)
          self._iterator = self._cm.__enter__()
          return self

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> bool | None:
          return self._cm.__exit__(exc_type, exc_val, exc_tb)

      def __aiter__(self) -> _SignalReceiver:
          return self

      async def __anext__(self) -> Signals:
          # The trio iterator yields ints; convert them to the Signals enum.
          return Signals(await self._iterator.__anext__())
  672. #
  673. # Testing and debugging
  674. #
  class TestRunner(abc.TestRunner):
      """Run async tests and fixtures inside a trio guest run driven by this thread.

      trio is started lazily via ``trio.lowlevel.start_guest_run``; the callbacks it
      schedules are posted to ``_call_queue`` and executed synchronously by
      whichever method is waiting for a result.
      """

      def __init__(self, **options: Any) -> None:
          from queue import Queue

          # Callbacks posted by the guest run, executed on the calling thread.
          self._call_queue: Queue[Callable[[], object]] = Queue()
          # Set by _run_tests_and_fixtures; reset to None by _main_task_finished.
          self._send_stream: MemoryObjectSendStream | None = None
          # Extra keyword arguments for trio.lowlevel.start_guest_run().
          self._options = options

      def _pump_until(self, finished: Callable[[], bool]) -> None:
          """Execute queued guest-run callbacks until ``finished()`` is true."""
          while not finished():
              self._call_queue.get()()

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: types.TracebackType | None,
      ) -> None:
          if not self._send_stream:
              return
          # Closing the stream ends the loop in _run_tests_and_fixtures; keep
          # running callbacks until the guest run reports completion.
          self._send_stream.close()
          self._pump_until(lambda: self._send_stream is None)

      async def _run_tests_and_fixtures(self) -> None:
          """Main trio task: await each submitted coroutine and record its outcome."""
          self._send_stream, receive_stream = create_memory_object_stream(1)
          with receive_stream:
              async for coro, result_box in receive_stream:
                  try:
                      value = await coro
                  except BaseException as exc:
                      result_box.append(Error(exc))
                  else:
                      result_box.append(Value(value))

      def _main_task_finished(self, outcome: object) -> None:
          # done_callback of the guest run: signals that trio has shut down.
          self._send_stream = None

      def _call_in_runner_task(
          self,
          func: Callable[P, Awaitable[T_Retval]],
          /,
          *args: P.args,
          **kwargs: P.kwargs,
      ) -> T_Retval:
          """Run ``func(*args, **kwargs)`` in the trio task and return its result."""
          if self._send_stream is None:
              trio.lowlevel.start_guest_run(
                  self._run_tests_and_fixtures,
                  run_sync_soon_threadsafe=self._call_queue.put,
                  done_callback=self._main_task_finished,
                  **self._options,
              )
              self._pump_until(lambda: self._send_stream is not None)

          result_box: list[Outcome] = []
          self._send_stream.send_nowait((func(*args, **kwargs), result_box))
          self._pump_until(lambda: bool(result_box))
          return result_box[0].unwrap()

      def run_asyncgen_fixture(
          self,
          fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
          kwargs: dict[str, Any],
      ) -> Iterable[T_Retval]:
          """Yield the fixture's single value, then drive it to completion.

          :raises RuntimeError: if the async generator yields a second time
          """
          agen = fixture_func(**kwargs)
          value: T_Retval = self._call_in_runner_task(agen.asend, None)

          yield value

          try:
              self._call_in_runner_task(agen.asend, None)
          except StopAsyncIteration:
              return
          self._call_in_runner_task(agen.aclose)
          raise RuntimeError("Async generator fixture did not stop")

      def run_fixture(
          self,
          fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
          kwargs: dict[str, Any],
      ) -> T_Retval:
          return self._call_in_runner_task(fixture_func, **kwargs)

      def run_test(
          self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
      ) -> None:
          self._call_in_runner_task(test_func, **kwargs)
  class TrioTaskInfo(TaskInfo):
      """Task information derived from a :class:`trio.lowlevel.Task`."""

      def __init__(self, task: trio.lowlevel.Task):
          nursery = task.parent_nursery
          parent_id = (
              id(nursery.parent_task) if nursery and nursery.parent_task else None
          )
          super().__init__(id(task), parent_id, task.name, task.coro)
          # Weak proxy so this info object does not keep the task alive.
          self._task = weakref.proxy(task)

      def has_pending_cancellation(self) -> bool:
          try:
              return self._task._cancel_status.effectively_cancelled
          except ReferenceError:
              # The task has been garbage collected, so no cancellation can be pending.
              return False
  763. class TrioBackend(AsyncBackend):
  764. @classmethod
  765. def run(
  766. cls,
  767. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  768. args: tuple[Unpack[PosArgsT]],
  769. kwargs: dict[str, Any],
  770. options: dict[str, Any],
  771. ) -> T_Retval:
  772. return trio.run(func, *args)
  773. @classmethod
  774. def current_token(cls) -> object:
  775. return trio.lowlevel.current_trio_token()
  776. @classmethod
  777. def current_time(cls) -> float:
  778. return trio.current_time()
  779. @classmethod
  780. def cancelled_exception_class(cls) -> type[BaseException]:
  781. return trio.Cancelled
  782. @classmethod
  783. async def checkpoint(cls) -> None:
  784. await trio.lowlevel.checkpoint()
  785. @classmethod
  786. async def checkpoint_if_cancelled(cls) -> None:
  787. await trio.lowlevel.checkpoint_if_cancelled()
  788. @classmethod
  789. async def cancel_shielded_checkpoint(cls) -> None:
  790. await trio.lowlevel.cancel_shielded_checkpoint()
  791. @classmethod
  792. async def sleep(cls, delay: float) -> None:
  793. await trio.sleep(delay)
  794. @classmethod
  795. def create_cancel_scope(
  796. cls, *, deadline: float = math.inf, shield: bool = False
  797. ) -> abc.CancelScope:
  798. return CancelScope(deadline=deadline, shield=shield)
  799. @classmethod
  800. def current_effective_deadline(cls) -> float:
  801. return trio.current_effective_deadline()
  802. @classmethod
  803. def create_task_group(cls) -> abc.TaskGroup:
  804. return TaskGroup()
  805. @classmethod
  806. def create_event(cls) -> abc.Event:
  807. return Event()
  808. @classmethod
  809. def create_lock(cls, *, fast_acquire: bool) -> Lock:
  810. return Lock(fast_acquire=fast_acquire)
  811. @classmethod
  812. def create_semaphore(
  813. cls,
  814. initial_value: int,
  815. *,
  816. max_value: int | None = None,
  817. fast_acquire: bool = False,
  818. ) -> abc.Semaphore:
  819. return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  820. @classmethod
  821. def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
  822. return CapacityLimiter(total_tokens)
  823. @classmethod
  824. async def run_sync_in_worker_thread(
  825. cls,
  826. func: Callable[[Unpack[PosArgsT]], T_Retval],
  827. args: tuple[Unpack[PosArgsT]],
  828. abandon_on_cancel: bool = False,
  829. limiter: abc.CapacityLimiter | None = None,
  830. ) -> T_Retval:
  831. def wrapper() -> T_Retval:
  832. with claim_worker_thread(TrioBackend, token):
  833. return func(*args)
  834. token = TrioBackend.current_token()
  835. return await run_sync(
  836. wrapper,
  837. abandon_on_cancel=abandon_on_cancel,
  838. limiter=cast(trio.CapacityLimiter, limiter),
  839. )
  840. @classmethod
  841. def check_cancelled(cls) -> None:
  842. trio.from_thread.check_cancelled()
  843. @classmethod
  844. def run_async_from_thread(
  845. cls,
  846. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  847. args: tuple[Unpack[PosArgsT]],
  848. token: object,
  849. ) -> T_Retval:
  850. trio_token = cast("trio.lowlevel.TrioToken | None", token)
  851. try:
  852. return trio.from_thread.run(func, *args, trio_token=trio_token)
  853. except trio.RunFinishedError:
  854. raise RunFinishedError from None
  855. @classmethod
  856. def run_sync_from_thread(
  857. cls,
  858. func: Callable[[Unpack[PosArgsT]], T_Retval],
  859. args: tuple[Unpack[PosArgsT]],
  860. token: object,
  861. ) -> T_Retval:
  862. trio_token = cast("trio.lowlevel.TrioToken | None", token)
  863. try:
  864. return trio.from_thread.run_sync(func, *args, trio_token=trio_token)
  865. except trio.RunFinishedError:
  866. raise RunFinishedError from None
  867. @classmethod
  868. async def open_process(
  869. cls,
  870. command: StrOrBytesPath | Sequence[StrOrBytesPath],
  871. *,
  872. stdin: int | IO[Any] | None,
  873. stdout: int | IO[Any] | None,
  874. stderr: int | IO[Any] | None,
  875. **kwargs: Any,
  876. ) -> Process:
  877. def convert_item(item: StrOrBytesPath) -> str:
  878. str_or_bytes = os.fspath(item)
  879. if isinstance(str_or_bytes, str):
  880. return str_or_bytes
  881. else:
  882. return os.fsdecode(str_or_bytes)
  883. if isinstance(command, (str, bytes, PathLike)):
  884. process = await trio.lowlevel.open_process(
  885. convert_item(command),
  886. stdin=stdin,
  887. stdout=stdout,
  888. stderr=stderr,
  889. shell=True,
  890. **kwargs,
  891. )
  892. else:
  893. process = await trio.lowlevel.open_process(
  894. [convert_item(item) for item in command],
  895. stdin=stdin,
  896. stdout=stdout,
  897. stderr=stderr,
  898. shell=False,
  899. **kwargs,
  900. )
  901. stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
  902. stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
  903. stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
  904. return Process(process, stdin_stream, stdout_stream, stderr_stream)
  905. @classmethod
  906. def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
  907. trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers)
  908. @classmethod
  909. async def connect_tcp(
  910. cls, host: str, port: int, local_address: IPSockAddrType | None = None
  911. ) -> SocketStream:
  912. family = socket.AF_INET6 if ":" in host else socket.AF_INET
  913. trio_socket = trio.socket.socket(family)
  914. trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  915. if local_address:
  916. await trio_socket.bind(local_address)
  917. try:
  918. await trio_socket.connect((host, port))
  919. except BaseException:
  920. trio_socket.close()
  921. raise
  922. return SocketStream(trio_socket)
  923. @classmethod
  924. async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
  925. trio_socket = trio.socket.socket(socket.AF_UNIX)
  926. try:
  927. await trio_socket.connect(path)
  928. except BaseException:
  929. trio_socket.close()
  930. raise
  931. return UNIXSocketStream(trio_socket)
  932. @classmethod
  933. def create_tcp_listener(cls, sock: socket.socket) -> abc.SocketListener:
  934. return TCPSocketListener(sock)
  935. @classmethod
  936. def create_unix_listener(cls, sock: socket.socket) -> abc.SocketListener:
  937. return UNIXSocketListener(sock)
  938. @classmethod
  939. async def create_udp_socket(
  940. cls,
  941. family: socket.AddressFamily,
  942. local_address: IPSockAddrType | None,
  943. remote_address: IPSockAddrType | None,
  944. reuse_port: bool,
  945. ) -> UDPSocket | ConnectedUDPSocket:
  946. trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
  947. if reuse_port:
  948. trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
  949. if local_address:
  950. await trio_socket.bind(local_address)
  951. if remote_address:
  952. await trio_socket.connect(remote_address)
  953. return ConnectedUDPSocket(trio_socket)
  954. else:
  955. return UDPSocket(trio_socket)
  956. @classmethod
  957. @overload
  958. async def create_unix_datagram_socket(
  959. cls, raw_socket: socket.socket, remote_path: None
  960. ) -> abc.UNIXDatagramSocket: ...
  961. @classmethod
  962. @overload
  963. async def create_unix_datagram_socket(
  964. cls, raw_socket: socket.socket, remote_path: str | bytes
  965. ) -> abc.ConnectedUNIXDatagramSocket: ...
  966. @classmethod
  967. async def create_unix_datagram_socket(
  968. cls, raw_socket: socket.socket, remote_path: str | bytes | None
  969. ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
  970. trio_socket = trio.socket.from_stdlib_socket(raw_socket)
  971. if remote_path:
  972. await trio_socket.connect(remote_path)
  973. return ConnectedUNIXDatagramSocket(trio_socket)
  974. else:
  975. return UNIXDatagramSocket(trio_socket)
  976. @classmethod
  977. async def getaddrinfo(
  978. cls,
  979. host: bytes | str | None,
  980. port: str | int | None,
  981. *,
  982. family: int | AddressFamily = 0,
  983. type: int | SocketKind = 0,
  984. proto: int = 0,
  985. flags: int = 0,
  986. ) -> Sequence[
  987. tuple[
  988. AddressFamily,
  989. SocketKind,
  990. int,
  991. str,
  992. tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
  993. ]
  994. ]:
  995. return await trio.socket.getaddrinfo(host, port, family, type, proto, flags)
  996. @classmethod
  997. async def getnameinfo(
  998. cls, sockaddr: IPSockAddrType, flags: int = 0
  999. ) -> tuple[str, str]:
  1000. return await trio.socket.getnameinfo(sockaddr, flags)
  1001. @classmethod
  1002. async def wait_readable(cls, obj: FileDescriptorLike) -> None:
  1003. try:
  1004. await wait_readable(obj)
  1005. except trio.ClosedResourceError as exc:
  1006. raise ClosedResourceError().with_traceback(exc.__traceback__) from None
  1007. except trio.BusyResourceError:
  1008. raise BusyResourceError("reading from") from None
  1009. @classmethod
  1010. async def wait_writable(cls, obj: FileDescriptorLike) -> None:
  1011. try:
  1012. await wait_writable(obj)
  1013. except trio.ClosedResourceError as exc:
  1014. raise ClosedResourceError().with_traceback(exc.__traceback__) from None
  1015. except trio.BusyResourceError:
  1016. raise BusyResourceError("writing to") from None
  1017. @classmethod
  1018. def notify_closing(cls, obj: FileDescriptorLike) -> None:
  1019. notify_closing(obj)
  1020. @classmethod
  1021. async def wrap_listener_socket(cls, sock: socket.socket) -> abc.SocketListener:
  1022. return TCPSocketListener(sock)
  1023. @classmethod
  1024. async def wrap_stream_socket(cls, sock: socket.socket) -> SocketStream:
  1025. trio_sock = trio.socket.from_stdlib_socket(sock)
  1026. return SocketStream(trio_sock)
  1027. @classmethod
  1028. async def wrap_unix_stream_socket(cls, sock: socket.socket) -> UNIXSocketStream:
  1029. trio_sock = trio.socket.from_stdlib_socket(sock)
  1030. return UNIXSocketStream(trio_sock)
  1031. @classmethod
  1032. async def wrap_udp_socket(cls, sock: socket.socket) -> UDPSocket:
  1033. trio_sock = trio.socket.from_stdlib_socket(sock)
  1034. return UDPSocket(trio_sock)
  1035. @classmethod
  1036. async def wrap_connected_udp_socket(cls, sock: socket.socket) -> ConnectedUDPSocket:
  1037. trio_sock = trio.socket.from_stdlib_socket(sock)
  1038. return ConnectedUDPSocket(trio_sock)
  1039. @classmethod
  1040. async def wrap_unix_datagram_socket(cls, sock: socket.socket) -> UNIXDatagramSocket:
  1041. trio_sock = trio.socket.from_stdlib_socket(sock)
  1042. return UNIXDatagramSocket(trio_sock)
  1043. @classmethod
  1044. async def wrap_connected_unix_datagram_socket(
  1045. cls, sock: socket.socket
  1046. ) -> ConnectedUNIXDatagramSocket:
  1047. trio_sock = trio.socket.from_stdlib_socket(sock)
  1048. return ConnectedUNIXDatagramSocket(trio_sock)
  1049. @classmethod
  1050. def current_default_thread_limiter(cls) -> CapacityLimiter:
  1051. try:
  1052. return _capacity_limiter_wrapper.get()
  1053. except LookupError:
  1054. limiter = CapacityLimiter(
  1055. original=trio.to_thread.current_default_thread_limiter()
  1056. )
  1057. _capacity_limiter_wrapper.set(limiter)
  1058. return limiter
  1059. @classmethod
  1060. def open_signal_receiver(
  1061. cls, *signals: Signals
  1062. ) -> AbstractContextManager[AsyncIterator[Signals]]:
  1063. return _SignalReceiver(signals)
  1064. @classmethod
  1065. def get_current_task(cls) -> TaskInfo:
  1066. task = current_task()
  1067. return TrioTaskInfo(task)
  1068. @classmethod
  1069. def get_running_tasks(cls) -> Sequence[TaskInfo]:
  1070. root_task = current_root_task()
  1071. assert root_task
  1072. task_infos = [TrioTaskInfo(root_task)]
  1073. nurseries = root_task.child_nurseries
  1074. while nurseries:
  1075. new_nurseries: list[trio.Nursery] = []
  1076. for nursery in nurseries:
  1077. for task in nursery.child_tasks:
  1078. task_infos.append(TrioTaskInfo(task))
  1079. new_nurseries.extend(task.child_nurseries)
  1080. nurseries = new_nurseries
  1081. return task_infos
  1082. @classmethod
  1083. async def wait_all_tasks_blocked(cls) -> None:
  1084. from trio.testing import wait_all_tasks_blocked
  1085. await wait_all_tasks_blocked()
  1086. @classmethod
  1087. def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
  1088. return TestRunner(**options)
  1089. backend_class = TrioBackend