decorators.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. """A few useful function/method decorators."""
  5. from __future__ import annotations
  6. import functools
  7. import inspect
  8. import sys
  9. import warnings
  10. from collections.abc import Callable, Generator
  11. from typing import ParamSpec, TypeVar
  12. from astroid import util
  13. from astroid.context import InferenceContext
  14. from astroid.exceptions import InferenceError
  15. from astroid.typing import InferenceResult
  16. _R = TypeVar("_R")
  17. _P = ParamSpec("_P")
  18. def path_wrapper(func):
  19. """Return the given infer function wrapped to handle the path.
  20. Used to stop inference if the node has already been looked
  21. at for a given `InferenceContext` to prevent infinite recursion
  22. """
  23. @functools.wraps(func)
  24. def wrapped(
  25. node, context: InferenceContext | None = None, _func=func, **kwargs
  26. ) -> Generator:
  27. """Wrapper function handling context."""
  28. if context is None:
  29. context = InferenceContext()
  30. if context.push(node):
  31. return
  32. yielded = set()
  33. for res in _func(node, context, **kwargs):
  34. # unproxy only true instance, not const, tuple, dict...
  35. if res.__class__.__name__ == "Instance":
  36. ares = res._proxied
  37. else:
  38. ares = res
  39. if ares not in yielded:
  40. yield res
  41. yielded.add(ares)
  42. return wrapped
  43. def yes_if_nothing_inferred(
  44. func: Callable[_P, Generator[InferenceResult]],
  45. ) -> Callable[_P, Generator[InferenceResult]]:
  46. def inner(*args: _P.args, **kwargs: _P.kwargs) -> Generator[InferenceResult]:
  47. generator = func(*args, **kwargs)
  48. try:
  49. yield next(generator)
  50. except StopIteration:
  51. # generator is empty
  52. yield util.Uninferable
  53. return
  54. yield from generator
  55. return inner
  56. def raise_if_nothing_inferred(
  57. func: Callable[_P, Generator[InferenceResult]],
  58. ) -> Callable[_P, Generator[InferenceResult]]:
  59. def inner(*args: _P.args, **kwargs: _P.kwargs) -> Generator[InferenceResult]:
  60. generator = func(*args, **kwargs)
  61. try:
  62. yield next(generator)
  63. except StopIteration as error:
  64. # generator is empty
  65. if error.args:
  66. raise InferenceError(**error.args[0]) from error
  67. raise InferenceError(
  68. "StopIteration raised without any error information."
  69. ) from error
  70. except RecursionError as error:
  71. raise InferenceError(
  72. f"RecursionError raised with limit {sys.getrecursionlimit()}."
  73. ) from error
  74. yield from generator
  75. return inner
  76. # Expensive decorators only used to emit Deprecation warnings.
  77. # If no other than the default DeprecationWarning are enabled,
  78. # fall back to passthrough implementations.
  79. if util.check_warnings_filter(): # noqa: C901
  80. def deprecate_default_argument_values(
  81. astroid_version: str = "3.0", **arguments: str
  82. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  83. """Decorator which emits a DeprecationWarning if any arguments specified
  84. are None or not passed at all.
  85. Arguments should be a key-value mapping, with the key being the argument to check
  86. and the value being a type annotation as string for the value of the argument.
  87. To improve performance, only used when DeprecationWarnings other than
  88. the default one are enabled.
  89. """
  90. # Helpful links
  91. # Decorator for DeprecationWarning: https://stackoverflow.com/a/49802489
  92. # Typing of stacked decorators: https://stackoverflow.com/a/68290080
  93. def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
  94. """Decorator function."""
  95. @functools.wraps(func)
  96. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  97. """Emit DeprecationWarnings if conditions are met."""
  98. keys = list(inspect.signature(func).parameters.keys())
  99. for arg, type_annotation in arguments.items():
  100. try:
  101. index = keys.index(arg)
  102. except ValueError:
  103. raise ValueError(
  104. f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
  105. ) from None
  106. # pylint: disable = too-many-boolean-expressions
  107. if (
  108. # Check kwargs
  109. # - if found, check it's not None
  110. (arg in kwargs and kwargs[arg] is None)
  111. # Check args
  112. # - make sure not in kwargs
  113. # - len(args) needs to be long enough, if too short
  114. # arg can't be in args either
  115. # - args[index] should not be None
  116. or (
  117. arg not in kwargs
  118. and (
  119. index == -1
  120. or len(args) <= index
  121. or (len(args) > index and args[index] is None)
  122. )
  123. )
  124. ):
  125. warnings.warn(
  126. f"'{arg}' will be a required argument for "
  127. f"'{args[0].__class__.__qualname__}.{func.__name__}'"
  128. f" in astroid {astroid_version} "
  129. f"('{arg}' should be of type: '{type_annotation}')",
  130. DeprecationWarning,
  131. stacklevel=2,
  132. )
  133. return func(*args, **kwargs)
  134. return wrapper
  135. return deco
  136. def deprecate_arguments(
  137. astroid_version: str = "3.0", **arguments: str
  138. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  139. """Decorator which emits a DeprecationWarning if any arguments specified
  140. are passed.
  141. Arguments should be a key-value mapping, with the key being the argument to check
  142. and the value being a string that explains what to do instead of passing the argument.
  143. To improve performance, only used when DeprecationWarnings other than
  144. the default one are enabled.
  145. """
  146. def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
  147. @functools.wraps(func)
  148. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  149. keys = list(inspect.signature(func).parameters.keys())
  150. for arg, note in arguments.items():
  151. try:
  152. index = keys.index(arg)
  153. except ValueError:
  154. raise ValueError(
  155. f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
  156. ) from None
  157. if arg in kwargs or len(args) > index:
  158. warnings.warn(
  159. f"The argument '{arg}' for "
  160. f"'{args[0].__class__.__qualname__}.{func.__name__}' is deprecated "
  161. f"and will be removed in astroid {astroid_version} ({note})",
  162. DeprecationWarning,
  163. stacklevel=2,
  164. )
  165. return func(*args, **kwargs)
  166. return wrapper
  167. return deco
  168. else:
  169. def deprecate_default_argument_values(
  170. astroid_version: str = "3.0", **arguments: str
  171. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  172. """Passthrough decorator to improve performance if DeprecationWarnings are
  173. disabled.
  174. """
  175. def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
  176. """Decorator function."""
  177. return func
  178. return deco
  179. def deprecate_arguments(
  180. astroid_version: str = "3.0", **arguments: str
  181. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  182. """Passthrough decorator to improve performance if DeprecationWarnings are
  183. disabled.
  184. """
  185. def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
  186. """Decorator function."""
  187. return func
  188. return deco