| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- import asyncio
- import dataclasses
- import inspect
- import random
- import sys
- import warnings
- from functools import _CacheInfo, _make_key, partial, partialmethod
- from typing import (
- Any,
- Callable,
- Coroutine,
- Generic,
- Hashable,
- List,
- Optional,
- OrderedDict,
- Type,
- TypedDict,
- TypeVar,
- Union,
- cast,
- final,
- overload,
- )
- if sys.version_info >= (3, 11):
- from typing import Self
- else:
- from typing_extensions import Self
- if sys.version_info < (3, 14):
- from asyncio.coroutines import _is_coroutine # type: ignore[attr-defined]
__version__ = "2.3.0"

# Names exported by ``from ... import *``.
__all__ = (
    "AlruCacheLoopResetWarning",
    "alru_cache",
)

# Type variables shared by the generic classes and helpers below.
_T = TypeVar("_T")
_R = TypeVar("_R")

# A coroutine object producing ``_R``.
_Coro = Coroutine[Any, Any, _R]
# A coroutine function producing ``_R``.
_CB = Callable[..., _Coro[_R]]
# Everything the decorator accepts: a coroutine function, or a
# partial/partialmethod around one.
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
class AlruCacheLoopResetWarning(UserWarning):
    """Warning raised when a cache notices it is running on a new event loop.

    The cache clears its (now unusable) entries automatically. The warning
    is issued at most once for each cache instance.
    """
@final
class _CacheParameters(TypedDict):
    """Shape of the mapping returned by ``cache_parameters()``."""

    # Whether argument types are part of the cache key.
    typed: bool
    # Upper bound on cached entries; None means unbounded.
    maxsize: Optional[int]
    # Number of cached tasks that have not finished yet.
    tasks: int
    # True once cache_close() has been called.
    closed: bool
@final
@dataclasses.dataclass
class _CacheItem(Generic[_R]):
    """One cache slot: the shared task plus its bookkeeping."""

    # Task computing the value; every caller for this key awaits it.
    task: "asyncio.Task[_R]"
    # Scheduled TTL-expiry handle, or None when nothing is scheduled.
    later_call: Optional[asyncio.Handle]
    # How many callers are currently awaiting ``task``.
    waiters: int

    def cancel(self) -> None:
        """Cancel the pending TTL expiry, if any; ``task`` is left alone."""
        handle, self.later_call = self.later_call, None
        if handle is not None:
            handle.cancel()
@final
class _LRUCacheWrapper(Generic[_R]):
    """LRU cache around a coroutine function.

    Calls with equal arguments (keyed by ``functools._make_key``) share one
    task. While the task runs, later callers await that same task. Once it
    succeeds, its result is served until one of these happens:

    - the entry is evicted by ``maxsize``;
    - it expires after ``ttl`` seconds plus a random ``0..jitter``;
    - it is invalidated;
    - the cache is cleared.

    Tasks that fail or are cancelled are removed instead of being kept.
    """

    def __init__(
        self,
        fn: _CB[_R],
        maxsize: Optional[int],
        typed: bool,
        ttl: Optional[float],
        jitter: Optional[float],
    ) -> None:
        # Copy the wrapped callable's metadata (like functools.update_wrapper),
        # skipping any attribute it does not define.
        for attr in (
            "__module__",
            "__name__",
            "__qualname__",
            "__doc__",
            "__annotations__",
        ):
            try:
                setattr(self, attr, getattr(fn, attr))
            except AttributeError:
                pass
        try:
            self.__dict__.update(fn.__dict__)
        except AttributeError:
            pass
        # set __wrapped__ last so we don't inadvertently copy it
        # from the wrapped function when updating __dict__
        if sys.version_info < (3, 14):
            # Marker checked by asyncio.iscoroutinefunction() on older Pythons.
            self._is_coroutine = _is_coroutine
        self.__wrapped__ = fn
        self.__maxsize = maxsize
        self.__typed = typed
        self.__ttl = ttl
        self.__jitter = jitter
        # Ordering doubles as LRU order: oldest entries first, hits moved last.
        self.__cache: OrderedDict[Hashable, _CacheItem[_R]] = OrderedDict()
        self.__closed = False
        self.__hits = 0
        self.__misses = 0
        # Loop the cached tasks/handles belong to; bound on the first call.
        self.__first_loop: Optional[asyncio.AbstractEventLoop] = None
        self.__warned_loop_reset = False

    @property
    def __tasks(self) -> List["asyncio.Task[_R]"]:
        """Cached tasks that have not finished yet."""
        # Each cache item owns a task created for it on a miss, so no
        # de-duplication is needed.
        return [
            cache_item.task
            for cache_item in self.__cache.values()
            if not cache_item.task.done()
        ]

    def __pop_if_current(self, key: Hashable, cache_item: _CacheItem[Any]) -> None:
        """Remove ``key`` only while it still maps to ``cache_item``.

        A task keeps running after its entry is invalidated, evicted or
        cleared, and the key may meanwhile hold a newer entry. Late callbacks
        for the old task must not remove (or re-time) that newer entry.
        """
        if self.__cache.get(key) is cache_item:
            del self.__cache[key]
            cache_item.cancel()

    def _check_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Bind to the first running loop; on a loop change, warn once and reset."""
        if self.__first_loop is None:
            self.__first_loop = loop
        elif self.__first_loop is not loop:
            if not self.__warned_loop_reset:
                warnings.warn(
                    "alru_cache detected event loop change and auto-cleared "
                    "stale entries. This is safe but unusual outside of "
                    "tests (pytest-anyio, etc.).",
                    AlruCacheLoopResetWarning,
                    stacklevel=3,
                )
                self.__warned_loop_reset = True
            # Old cache entries hold tasks/handles bound to the previous
            # loop and are invalid here. Clear and rebind.
            self.cache_clear()
            self.__first_loop = loop

    def cache_contains(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Check if the given arguments are in the cache.

        Does not affect hit/miss counters or LRU ordering.
        """
        key = _make_key(args, kwargs, self.__typed)
        return key in self.__cache

    def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Drop the entry for the given arguments.

        Any pending TTL expiry is cancelled. A still-running task is not
        cancelled. Returns True if an entry was removed.
        """
        key = _make_key(args, kwargs, self.__typed)
        cache_item = self.__cache.pop(key, None)
        if cache_item is None:
            return False
        cache_item.cancel()
        return True

    def cache_clear(self) -> None:
        """Remove every entry, cancel TTL expiries and reset hit/miss counters."""
        self.__hits = 0
        self.__misses = 0
        for cache_item in self.__cache.values():
            cache_item.cancel()
        self.__cache.clear()

    async def cache_close(self, *, wait: bool = False) -> None:
        """Close the cache so further calls raise ``RuntimeError``.

        With ``wait=False``, unfinished cached tasks are cancelled first.
        In both modes this waits for those tasks to settle. Their exceptions
        are collected, not raised.
        """
        self.__closed = True
        tasks = self.__tasks
        if not tasks:
            return
        if not wait:
            for task in tasks:
                if not task.done():
                    task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    def cache_info(self) -> _CacheInfo:
        """Return ``(hits, misses, maxsize, currsize)`` like functools.lru_cache."""
        return _CacheInfo(
            self.__hits,
            self.__misses,
            self.__maxsize,
            len(self.__cache),
        )

    def cache_parameters(self) -> _CacheParameters:
        """Return the cache configuration plus in-flight task count and state."""
        return _CacheParameters(
            maxsize=self.__maxsize,
            typed=self.__typed,
            tasks=len(self.__tasks),
            closed=self.__closed,
        )

    def _cache_hit(self, key: Hashable) -> None:
        self.__hits += 1
        # Mark as most recently used.
        self.__cache.move_to_end(key)

    def _cache_miss(self, key: Hashable) -> None:
        self.__misses += 1

    def _task_done_callback(self, key: Hashable, task: "asyncio.Task[_R]") -> None:
        """Finalize the cache entry once its task finishes.

        Failed or cancelled tasks are dropped. Successful ones get a TTL
        expiry when ``ttl`` is configured.
        """
        cache_item = self.__cache.get(key)
        if cache_item is None or cache_item.task is not task:
            # This task's entry is already gone (invalidated, evicted, cleared
            # or reset after a loop change). The key may now belong to a newer
            # call, which must not be touched.
            return
        # We must use the private attribute instead of `exception()`
        # so asyncio does not set `task.__log_traceback = False` on
        # the false assumption that the caller read the task Exception
        if task.cancelled() or task._exception is not None:
            self.__pop_if_current(key, cache_item)
            return
        if self.__ttl is not None:
            effective_ttl = self.__ttl
            if self.__jitter is not None:
                # Add a random 0..jitter seconds on top of the base TTL.
                effective_ttl += random.uniform(0, self.__jitter)
            # Never leave an earlier expiry handle orphaned.
            cache_item.cancel()
            cache_item.later_call = task.get_loop().call_later(
                effective_ttl, self.__pop_if_current, key, cache_item
            )

    async def _shield_and_handle_cancelled_error(
        self, cache_item: _CacheItem[_T], key: Hashable
    ) -> _T:
        """Await ``cache_item.task`` on behalf of one waiter.

        The task is shielded, so cancelling one waiter leaves the shared task
        running. Only when the last waiter is cancelled is the unfinished task
        cancelled and its entry dropped.
        """
        task = cache_item.task
        try:
            # All waiters await the same shielded task.
            return await asyncio.shield(task)
        except asyncio.CancelledError:
            # If this is the last waiter and the underlying task is not done,
            # cancel the underlying task and remove the cache entry.
            if cache_item.waiters == 1 and not task.done():
                cache_item.cancel()  # Cancel TTL expiration
                task.cancel()  # Cancel the running coroutine
                # Remove from cache, but only if the key still maps to this
                # item (it may have been invalidated and re-filled meanwhile).
                self.__pop_if_current(key, cache_item)
            raise
        finally:
            # Each logical waiter decrements waiters on exit (normal or cancelled).
            cache_item.waiters -= 1

    async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
        """Return the cached result for the arguments, computing it on a miss.

        Raises:
            RuntimeError: if the cache has been closed.
        """
        if self.__closed:
            raise RuntimeError(f"alru_cache is closed for {self}")
        loop = asyncio.get_running_loop()
        self._check_loop(loop)
        key = _make_key(fn_args, fn_kwargs, self.__typed)

        cache_item = self.__cache.get(key)
        if cache_item is not None:
            self._cache_hit(key)
            if cache_item.task.done():
                # If the task is already done, just return the result.
                return cache_item.task.result()
            # Each logical waiter increments waiters on entry.
            cache_item.waiters += 1
            return await self._shield_and_handle_cancelled_error(cache_item, key)

        coro = self.__wrapped__(*fn_args, **fn_kwargs)
        task: asyncio.Task[_R] = loop.create_task(coro)
        task.add_done_callback(partial(self._task_done_callback, key))
        cache_item = _CacheItem(task, None, 1)
        self.__cache[key] = cache_item
        if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
            # Evict the least recently used entry (its task keeps running).
            _, dropped_cache_item = self.__cache.popitem(last=False)
            dropped_cache_item.cancel()
        self._cache_miss(key)
        return await self._shield_and_handle_cancelled_error(cache_item, key)

    def __get__(
        self, instance: _T, owner: Optional[Type[_T]]
    ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
        """Bind the cache to ``instance`` when used as a method."""
        # NOTE(review): normal attribute lookup always passes ``owner``, so
        # class-level access (``Owner.method``) also returns a bound wrapper
        # whose instance is None; kept as-is for backward compatibility.
        if owner is None:
            return self
        return _LRUCacheWrapperInstanceMethod(self, instance)
@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
    """An ``_LRUCacheWrapper`` bound to one instance.

    Calls and the per-key helpers (``cache_contains``/``cache_invalidate``)
    pass the bound instance as the first argument. The other helpers are
    forwarded unchanged to the shared wrapper.
    """

    def __init__(
        self,
        wrapper: _LRUCacheWrapper[_R],
        instance: _T,
    ) -> None:
        # Mirror the wrapper's metadata; attributes it lacks are skipped.
        copied = ("__module__", "__name__", "__qualname__", "__doc__", "__annotations__")
        for name in copied:
            try:
                setattr(self, name, getattr(wrapper, name))
            except AttributeError:
                pass
        try:
            self.__dict__.update(wrapper.__dict__)
        except AttributeError:
            pass
        # set __wrapped__ last so we don't inadvertently copy it
        # from the wrapped function when updating __dict__
        if sys.version_info < (3, 14):
            self._is_coroutine = _is_coroutine
        self.__wrapped__ = wrapper.__wrapped__
        self.__instance = instance
        self.__wrapper = wrapper

    def cache_contains(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Whether the bound instance plus ``args``/``kwargs`` is cached."""
        target = self.__wrapper
        return target.cache_contains(self.__instance, *args, **kwargs)

    def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Drop the entry for the bound instance plus ``args``/``kwargs``."""
        target = self.__wrapper
        return target.cache_invalidate(self.__instance, *args, **kwargs)

    def cache_clear(self) -> None:
        """Clear the whole shared cache (not only this instance's entries)."""
        self.__wrapper.cache_clear()

    async def cache_close(
        self,
        *,
        wait: bool = False,
        cancel: bool = False,
        return_exceptions: bool = True,
    ) -> None:
        """Close the shared cache.

        ``cancel`` and ``return_exceptions`` are accepted only for backward
        compatibility. Passing non-default values emits a
        DeprecationWarning, and the values are otherwise ignored.
        """
        uses_deprecated_args = cancel or return_exceptions is not True
        if uses_deprecated_args:
            warnings.warn(
                "cancel/return_exceptions are deprecated; use wait=True to allow tasks "
                "to finish and wait=False to cancel pending tasks.",
                DeprecationWarning,
                stacklevel=2,
            )
        await self.__wrapper.cache_close(wait=wait)

    def cache_info(self) -> _CacheInfo:
        """Statistics of the shared cache."""
        return self.__wrapper.cache_info()

    def cache_parameters(self) -> _CacheParameters:
        """Parameters of the shared cache."""
        return self.__wrapper.cache_parameters()

    async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
        """Call the cached function with the bound instance prepended."""
        bound = self.__instance
        return await self.__wrapper(bound, *fn_args, **fn_kwargs)
def _make_wrapper(
    maxsize: Optional[int],
    typed: bool,
    ttl: Optional[float] = None,
    jitter: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
    """Validate the options and return a decorator producing the cache wrapper.

    Raises:
        ValueError: if ``jitter`` is given without ``ttl``, or is negative.
    """
    if jitter is not None:
        if ttl is None:
            raise ValueError("jitter requires ttl to be set")
        if jitter < 0:
            raise ValueError("jitter must be non-negative")

    def decorator(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
        # Look through (possibly nested) partial/partialmethod layers so the
        # coroutine-function check applies to the real target.
        target = fn
        while isinstance(target, (partial, partialmethod)):
            target = target.func
        if not inspect.iscoroutinefunction(target):
            raise RuntimeError(f"Coroutine function is required, got {fn!r}")

        if hasattr(fn, "_make_unbound_method"):
            fn = fn._make_unbound_method()

        cached = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl, jitter)
        if sys.version_info >= (3, 12):
            # Let inspect.iscoroutinefunction() recognise the wrapper object.
            cached = inspect.markcoroutinefunction(cached)
        return cached

    return decorator
@overload
def alru_cache(
    maxsize: Optional[int] = 128,
    typed: bool = False,
    *,
    ttl: Optional[float] = None,
    jitter: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
    ...


@overload
def alru_cache(
    maxsize: _CBP[_R],
    /,
) -> _LRUCacheWrapper[_R]:
    ...


def alru_cache(
    maxsize: Union[Optional[int], _CBP[_R]] = 128,
    typed: bool = False,
    *,
    ttl: Optional[float] = None,
    jitter: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
    """Cache the results of a coroutine function (LRU, optional TTL + jitter).

    Works both bare (``@alru_cache``) and with options
    (``@alru_cache(maxsize=..., typed=..., ttl=..., jitter=...)``).

    Raises:
        NotImplementedError: if the object being decorated is unsupported.
    """
    # Configured form: the first argument is a size limit (or None).
    if maxsize is None or isinstance(maxsize, int):
        return _make_wrapper(maxsize, typed, ttl, jitter)

    # Bare-decorator form: the first argument is the function itself,
    # decorated with the default settings.
    fn = cast(_CB[_R], maxsize)
    supported = callable(fn) or hasattr(fn, "_make_unbound_method")
    if not supported:
        raise NotImplementedError(f"{fn!r} decorating is not supported")
    return _make_wrapper(128, False, None, None)(fn)
|