__init__.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. import asyncio
  2. import dataclasses
  3. import inspect
  4. import random
  5. import sys
  6. import warnings
  7. from functools import _CacheInfo, _make_key, partial, partialmethod
  8. from typing import (
  9. Any,
  10. Callable,
  11. Coroutine,
  12. Generic,
  13. Hashable,
  14. List,
  15. Optional,
  16. OrderedDict,
  17. Type,
  18. TypedDict,
  19. TypeVar,
  20. Union,
  21. cast,
  22. final,
  23. overload,
  24. )
  25. if sys.version_info >= (3, 11):
  26. from typing import Self
  27. else:
  28. from typing_extensions import Self
  29. if sys.version_info < (3, 14):
  30. from asyncio.coroutines import _is_coroutine # type: ignore[attr-defined]
  31. __version__ = "2.3.0"
  32. __all__ = ("AlruCacheLoopResetWarning", "alru_cache")
  33. _T = TypeVar("_T")
  34. _R = TypeVar("_R")
  35. _Coro = Coroutine[Any, Any, _R]
  36. _CB = Callable[..., _Coro[_R]]
  37. _CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
  38. class AlruCacheLoopResetWarning(UserWarning):
  39. """Emitted once per cache instance when a loop change triggers an auto-reset."""
  40. @final
  41. class _CacheParameters(TypedDict):
  42. typed: bool
  43. maxsize: Optional[int]
  44. tasks: int
  45. closed: bool
  46. @final
  47. @dataclasses.dataclass
  48. class _CacheItem(Generic[_R]):
  49. task: "asyncio.Task[_R]"
  50. later_call: Optional[asyncio.Handle]
  51. waiters: int
  52. def cancel(self) -> None:
  53. if self.later_call is not None:
  54. self.later_call.cancel()
  55. self.later_call = None
  56. @final
  57. class _LRUCacheWrapper(Generic[_R]):
  58. def __init__(
  59. self,
  60. fn: _CB[_R],
  61. maxsize: Optional[int],
  62. typed: bool,
  63. ttl: Optional[float],
  64. jitter: Optional[float],
  65. ) -> None:
  66. try:
  67. self.__module__ = fn.__module__
  68. except AttributeError:
  69. pass
  70. try:
  71. self.__name__ = fn.__name__
  72. except AttributeError:
  73. pass
  74. try:
  75. self.__qualname__ = fn.__qualname__
  76. except AttributeError:
  77. pass
  78. try:
  79. self.__doc__ = fn.__doc__
  80. except AttributeError:
  81. pass
  82. try:
  83. self.__annotations__ = fn.__annotations__
  84. except AttributeError:
  85. pass
  86. try:
  87. self.__dict__.update(fn.__dict__)
  88. except AttributeError:
  89. pass
  90. # set __wrapped__ last so we don't inadvertently copy it
  91. # from the wrapped function when updating __dict__
  92. if sys.version_info < (3, 14):
  93. self._is_coroutine = _is_coroutine
  94. self.__wrapped__ = fn
  95. self.__maxsize = maxsize
  96. self.__typed = typed
  97. self.__ttl = ttl
  98. self.__jitter = jitter
  99. self.__cache: OrderedDict[Hashable, _CacheItem[_R]] = OrderedDict()
  100. self.__closed = False
  101. self.__hits = 0
  102. self.__misses = 0
  103. self.__first_loop: Optional[asyncio.AbstractEventLoop] = None
  104. self.__warned_loop_reset = False
  105. @property
  106. def __tasks(self) -> List["asyncio.Task[_R]"]:
  107. # NOTE: I don't think we need to form a set first here but not
  108. # too sure we want it for guarantees
  109. return list(
  110. {
  111. cache_item.task
  112. for cache_item in self.__cache.values()
  113. if not cache_item.task.done()
  114. }
  115. )
  116. def _check_loop(self, loop: asyncio.AbstractEventLoop) -> None:
  117. if self.__first_loop is None:
  118. self.__first_loop = loop
  119. elif self.__first_loop is not loop:
  120. if not self.__warned_loop_reset:
  121. warnings.warn(
  122. "alru_cache detected event loop change and auto-cleared "
  123. "stale entries. This is safe but unusual outside of "
  124. "tests (pytest-anyio, etc.).",
  125. AlruCacheLoopResetWarning,
  126. stacklevel=3,
  127. )
  128. self.__warned_loop_reset = True
  129. # Old cache entries hold tasks/handles bound to the previous
  130. # loop and are invalid here. Clear and rebind.
  131. self.cache_clear()
  132. self.__first_loop = loop
  133. def cache_contains(self, /, *args: Hashable, **kwargs: Any) -> bool:
  134. """Check if the given arguments are in the cache.
  135. Does not affect hit/miss counters or LRU ordering.
  136. """
  137. key = _make_key(args, kwargs, self.__typed)
  138. return key in self.__cache
  139. def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
  140. key = _make_key(args, kwargs, self.__typed)
  141. cache_item = self.__cache.pop(key, None)
  142. if cache_item is None:
  143. return False
  144. else:
  145. cache_item.cancel()
  146. return True
  147. def cache_clear(self) -> None:
  148. self.__hits = 0
  149. self.__misses = 0
  150. for c in self.__cache.values():
  151. if c.later_call:
  152. c.later_call.cancel()
  153. self.__cache.clear()
  154. async def cache_close(self, *, wait: bool = False) -> None:
  155. self.__closed = True
  156. tasks = self.__tasks
  157. if not tasks:
  158. return
  159. if not wait:
  160. for task in tasks:
  161. if not task.done():
  162. task.cancel()
  163. await asyncio.gather(*tasks, return_exceptions=True)
  164. def cache_info(self) -> _CacheInfo:
  165. return _CacheInfo(
  166. self.__hits,
  167. self.__misses,
  168. self.__maxsize,
  169. len(self.__cache),
  170. )
  171. def cache_parameters(self) -> _CacheParameters:
  172. return _CacheParameters(
  173. maxsize=self.__maxsize,
  174. typed=self.__typed,
  175. tasks=len(self.__tasks),
  176. closed=self.__closed,
  177. )
  178. def _cache_hit(self, key: Hashable) -> None:
  179. self.__hits += 1
  180. self.__cache.move_to_end(key)
  181. def _cache_miss(self, key: Hashable) -> None:
  182. self.__misses += 1
  183. def _task_done_callback(self, key: Hashable, task: "asyncio.Task[_R]") -> None:
  184. # We must use the private attribute instead of `exception()`
  185. # so asyncio does not set `task.__log_traceback = False` on
  186. # the false assumption that the caller read the task Exception
  187. if task.cancelled() or task._exception is not None:
  188. self.__cache.pop(key, None)
  189. return
  190. cache_item = self.__cache.get(key)
  191. if self.__ttl is not None and cache_item is not None:
  192. effective_ttl = self.__ttl
  193. if self.__jitter is not None:
  194. effective_ttl += random.uniform(0, self.__jitter)
  195. loop = asyncio.get_running_loop()
  196. cache_item.later_call = loop.call_later(
  197. effective_ttl, self.__cache.pop, key, None
  198. )
  199. async def _shield_and_handle_cancelled_error(
  200. self, cache_item: _CacheItem[_T], key: Hashable
  201. ) -> _T:
  202. task = cache_item.task
  203. try:
  204. # All waiters await the same shielded task.
  205. return await asyncio.shield(task)
  206. except asyncio.CancelledError:
  207. # If this is the last waiter and the underlying task is not done,
  208. # cancel the underlying task and remove the cache entry.
  209. if cache_item.waiters == 1 and not task.done():
  210. cache_item.cancel() # Cancel TTL expiration
  211. task.cancel() # Cancel the running coroutine
  212. self.__cache.pop(key, None) # Remove from cache
  213. raise
  214. finally:
  215. # Each logical waiter decrements waiters on exit (normal or cancelled).
  216. cache_item.waiters -= 1
  217. async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
  218. if self.__closed:
  219. raise RuntimeError(f"alru_cache is closed for {self}")
  220. loop = asyncio.get_running_loop()
  221. self._check_loop(loop)
  222. key = _make_key(fn_args, fn_kwargs, self.__typed)
  223. cache_item = self.__cache.get(key)
  224. if cache_item is not None:
  225. self._cache_hit(key)
  226. if not cache_item.task.done():
  227. # Each logical waiter increments waiters on entry.
  228. cache_item.waiters += 1
  229. return await self._shield_and_handle_cancelled_error(cache_item, key)
  230. # If the task is already done, just return the result.
  231. return cache_item.task.result()
  232. coro = self.__wrapped__(*fn_args, **fn_kwargs)
  233. task: asyncio.Task[_R] = loop.create_task(coro)
  234. task.add_done_callback(partial(self._task_done_callback, key))
  235. cache_item = _CacheItem(task, None, 1)
  236. self.__cache[key] = cache_item
  237. if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
  238. dropped_key, dropped_cache_item = self.__cache.popitem(last=False)
  239. dropped_cache_item.cancel()
  240. self._cache_miss(key)
  241. return await self._shield_and_handle_cancelled_error(cache_item, key)
  242. def __get__(
  243. self, instance: _T, owner: Optional[Type[_T]]
  244. ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
  245. if owner is None:
  246. return self
  247. else:
  248. return _LRUCacheWrapperInstanceMethod(self, instance)
  249. @final
  250. class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
  251. def __init__(
  252. self,
  253. wrapper: _LRUCacheWrapper[_R],
  254. instance: _T,
  255. ) -> None:
  256. try:
  257. self.__module__ = wrapper.__module__
  258. except AttributeError:
  259. pass
  260. try:
  261. self.__name__ = wrapper.__name__
  262. except AttributeError:
  263. pass
  264. try:
  265. self.__qualname__ = wrapper.__qualname__
  266. except AttributeError:
  267. pass
  268. try:
  269. self.__doc__ = wrapper.__doc__
  270. except AttributeError:
  271. pass
  272. try:
  273. self.__annotations__ = wrapper.__annotations__
  274. except AttributeError:
  275. pass
  276. try:
  277. self.__dict__.update(wrapper.__dict__)
  278. except AttributeError:
  279. pass
  280. # set __wrapped__ last so we don't inadvertently copy it
  281. # from the wrapped function when updating __dict__
  282. if sys.version_info < (3, 14):
  283. self._is_coroutine = _is_coroutine
  284. self.__wrapped__ = wrapper.__wrapped__
  285. self.__instance = instance
  286. self.__wrapper = wrapper
  287. def cache_contains(self, /, *args: Hashable, **kwargs: Any) -> bool:
  288. return self.__wrapper.cache_contains(self.__instance, *args, **kwargs)
  289. def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
  290. return self.__wrapper.cache_invalidate(self.__instance, *args, **kwargs)
  291. def cache_clear(self) -> None:
  292. self.__wrapper.cache_clear()
  293. async def cache_close(
  294. self,
  295. *,
  296. wait: bool = False,
  297. cancel: bool = False,
  298. return_exceptions: bool = True,
  299. ) -> None:
  300. if cancel or return_exceptions is not True:
  301. warnings.warn(
  302. "cancel/return_exceptions are deprecated; use wait=True to allow tasks "
  303. "to finish and wait=False to cancel pending tasks.",
  304. DeprecationWarning,
  305. stacklevel=2,
  306. )
  307. await self.__wrapper.cache_close(wait=wait)
  308. def cache_info(self) -> _CacheInfo:
  309. return self.__wrapper.cache_info()
  310. def cache_parameters(self) -> _CacheParameters:
  311. return self.__wrapper.cache_parameters()
  312. async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
  313. return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
  314. def _make_wrapper(
  315. maxsize: Optional[int],
  316. typed: bool,
  317. ttl: Optional[float] = None,
  318. jitter: Optional[float] = None,
  319. ) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
  320. if jitter is not None and ttl is None:
  321. raise ValueError("jitter requires ttl to be set")
  322. if jitter is not None and jitter < 0:
  323. raise ValueError("jitter must be non-negative")
  324. def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
  325. origin = fn
  326. while isinstance(origin, (partial, partialmethod)):
  327. origin = origin.func
  328. if not inspect.iscoroutinefunction(origin):
  329. raise RuntimeError(f"Coroutine function is required, got {fn!r}")
  330. if hasattr(fn, "_make_unbound_method"):
  331. fn = fn._make_unbound_method()
  332. wrapper = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl, jitter)
  333. if sys.version_info >= (3, 12):
  334. wrapper = inspect.markcoroutinefunction(wrapper)
  335. return wrapper
  336. return wrapper
  337. @overload
  338. def alru_cache(
  339. maxsize: Optional[int] = 128,
  340. typed: bool = False,
  341. *,
  342. ttl: Optional[float] = None,
  343. jitter: Optional[float] = None,
  344. ) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
  345. ...
  346. @overload
  347. def alru_cache(
  348. maxsize: _CBP[_R],
  349. /,
  350. ) -> _LRUCacheWrapper[_R]:
  351. ...
  352. def alru_cache(
  353. maxsize: Union[Optional[int], _CBP[_R]] = 128,
  354. typed: bool = False,
  355. *,
  356. ttl: Optional[float] = None,
  357. jitter: Optional[float] = None,
  358. ) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
  359. if maxsize is None or isinstance(maxsize, int):
  360. return _make_wrapper(maxsize, typed, ttl, jitter)
  361. else:
  362. fn = cast(_CB[_R], maxsize)
  363. if callable(fn) or hasattr(fn, "_make_unbound_method"):
  364. return _make_wrapper(128, False, None, None)(fn)
  365. raise NotImplementedError(f"{fn!r} decorating is not supported")