_contextlib.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # mypy: allow-untyped-defs
  2. # Extra utilities for working with context managers that should have been
  3. # in the standard library but are not
  4. import functools
  5. import inspect
  6. import sys
  7. import warnings
  8. from collections.abc import Callable
  9. from typing import Any, cast, overload, TypeVar
  10. from typing_extensions import Self
  11. # Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
  12. # 'no_grad' and 'enable_grad').
  13. # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
  14. FuncType = Callable[..., Any]
  15. F = TypeVar("F", bound=FuncType)
  16. def _wrap_generator(ctx_factory, func):
  17. """
  18. Wrap each generator invocation with the context manager factory.
  19. The input should be a function that returns a context manager,
  20. not a context manager itself, to handle one-shot context managers.
  21. """
  22. @functools.wraps(func)
  23. def generator_context(*args, **kwargs):
  24. gen = func(*args, **kwargs)
  25. # Generators are suspended and unsuspended at `yield`, hence we
  26. # make sure the grad mode is properly set every time the execution
  27. # flow returns into the wrapped generator and restored when it
  28. # returns through our `yield` to our caller (see PR #49017).
  29. try:
  30. # Issuing `None` to a generator fires it up
  31. with ctx_factory():
  32. response = gen.send(None)
  33. while True:
  34. try:
  35. # Forward the response to our caller and get its next request
  36. request = yield response
  37. except GeneratorExit:
  38. # Inform the still active generator about its imminent closure
  39. with ctx_factory():
  40. gen.close()
  41. raise
  42. except BaseException: # noqa: B036
  43. # Propagate the exception thrown at us by the caller
  44. with ctx_factory():
  45. response = gen.throw(*sys.exc_info())
  46. else:
  47. # Pass the last request to the generator and get its response
  48. with ctx_factory():
  49. response = gen.send(request)
  50. # We let the exceptions raised above by the generator's `.throw` or
  51. # `.send` methods bubble up to our caller, except for StopIteration
  52. except StopIteration as e:
  53. # The generator informed us that it is done: take whatever its
  54. # returned value (if any) was and indicate that we're done too
  55. # by returning it (see docs for python's return-statement).
  56. return e.value
  57. return generator_context
  58. def context_decorator(ctx, func):
  59. """
  60. Like contextlib.ContextDecorator.
  61. But with the following differences:
  62. 1. Is done by wrapping, rather than inheritance, so it works with context
  63. managers that are implemented from C and thus cannot easily inherit from
  64. Python classes
  65. 2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
  66. 3. Errors out if you try to wrap a class, because it is ambiguous whether
  67. or not you intended to wrap only the constructor
  68. The input argument can either be a context manager (in which case it must
  69. be a multi-shot context manager that can be directly invoked multiple times)
  70. or a callable that produces a context manager.
  71. """
  72. if callable(ctx) and hasattr(ctx, "__enter__"):
  73. raise AssertionError(
  74. f"Passed in {ctx} is both callable and also a valid context manager "
  75. "(has __enter__), making it ambiguous which interface to use. If you "
  76. "intended to pass a context manager factory, rewrite your call as "
  77. "context_decorator(lambda: ctx()); if you intended to pass a context "
  78. "manager directly, rewrite your call as context_decorator(lambda: ctx)"
  79. )
  80. if not callable(ctx):
  81. def ctx_factory():
  82. return ctx
  83. else:
  84. ctx_factory = ctx
  85. if inspect.isclass(func):
  86. raise RuntimeError(
  87. "Cannot decorate classes; it is ambiguous whether or not only the "
  88. "constructor or all methods should have the context manager applied; "
  89. "additionally, decorating a class at definition-site will prevent "
  90. "use of the identifier as a conventional type. "
  91. "To specify which methods to decorate, decorate each of them "
  92. "individually."
  93. )
  94. if inspect.isgeneratorfunction(func):
  95. return _wrap_generator(ctx_factory, func)
  96. @functools.wraps(func)
  97. def decorate_context(*args, **kwargs):
  98. # pyrefly: ignore [bad-context-manager]
  99. with ctx_factory():
  100. return func(*args, **kwargs)
  101. return decorate_context
  102. class _DecoratorContextManager:
  103. """Allow a context manager to be used as a decorator."""
  104. def __call__(self, orig_func: F) -> F:
  105. if inspect.isclass(orig_func):
  106. warnings.warn(
  107. "Decorating classes is deprecated and will be disabled in "
  108. "future versions. You should only decorate functions or methods. "
  109. "To preserve the current behavior of class decoration, you can "
  110. "directly decorate the `__init__` method and nothing else.",
  111. FutureWarning,
  112. stacklevel=2,
  113. )
  114. func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
  115. else:
  116. func = orig_func
  117. return cast(F, context_decorator(self.clone, func))
  118. def __enter__(self) -> None:
  119. raise NotImplementedError
  120. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  121. raise NotImplementedError
  122. def clone(self):
  123. # override this method if your children class takes __init__ parameters
  124. return self.__class__()
  125. class _NoParamDecoratorContextManager(_DecoratorContextManager):
  126. """Allow a context manager to be used as a decorator without parentheses."""
  127. @overload
  128. def __new__(cls, orig_func: F) -> F: ... # type: ignore[misc]
  129. @overload
  130. def __new__(cls, orig_func: None = None) -> Self: ...
  131. def __new__(cls, orig_func: F | None = None) -> Self | F: # type: ignore[misc]
  132. if orig_func is None:
  133. return super().__new__(cls)
  134. return cls()(orig_func)