utils.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """Utilities"""
  2. from __future__ import annotations
  3. import asyncio
  4. import sys
  5. import typing as t
  6. from collections.abc import Mapping
  7. from contextvars import copy_context
  8. from functools import partial, wraps
  9. if t.TYPE_CHECKING:
  10. from collections.abc import Callable
  11. from contextvars import Context
  12. class LazyDict(Mapping[str, t.Any]):
  13. """Lazy evaluated read-only dictionary.
  14. Initialised with a dictionary of key-value pairs where the values are either
  15. constants or callables. Callables are evaluated each time the respective item is
  16. read.
  17. """
  18. def __init__(self, dict):
  19. self._dict = dict
  20. def __getitem__(self, key):
  21. item = self._dict.get(key)
  22. return item() if callable(item) else item
  23. def __len__(self):
  24. return len(self._dict)
  25. def __iter__(self):
  26. return iter(self._dict)
# Generic placeholders used only in type annotations in this module: they
# stand for the three type parameters of ``t.Coroutine`` (yield type, send
# type and return type, in that order) in the signature of
# ``_async_in_context``.
T = t.TypeVar("T")
U = t.TypeVar("U")
V = t.TypeVar("V")
  30. def _async_in_context(
  31. f: Callable[..., t.Coroutine[T, U, V]], context: Context | None = None
  32. ) -> Callable[..., t.Coroutine[T, U, V]]:
  33. """
  34. Wrapper to run a coroutine in a persistent ContextVar Context.
  35. Backports asyncio.create_task(context=...) behavior from Python 3.11
  36. """
  37. if context is None:
  38. context = copy_context()
  39. if sys.version_info >= (3, 11):
  40. @wraps(f)
  41. async def run_in_context(*args, **kwargs):
  42. coro = f(*args, **kwargs)
  43. return await asyncio.create_task(coro, context=context)
  44. return run_in_context
  45. # don't need this backport when we require 3.11
  46. # context_holder so we have a modifiable container for later calls
  47. context_holder = [context] # type: ignore[unreachable]
  48. async def preserve_context(f, *args, **kwargs):
  49. """call a coroutine, preserving the context after it is called"""
  50. try:
  51. return await f(*args, **kwargs)
  52. finally:
  53. # persist changes to the context for future calls
  54. context_holder[0] = copy_context()
  55. @wraps(f)
  56. async def run_in_context_pre311(*args, **kwargs):
  57. ctx = context_holder[0]
  58. return await ctx.run(partial(asyncio.create_task, preserve_context(f, *args, **kwargs)))
  59. return run_in_context_pre311