transforms.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. import warnings
  6. from collections import defaultdict
  7. from collections.abc import Callable
  8. from typing import TYPE_CHECKING, TypeVar, Union, cast, overload
  9. from astroid.context import _invalidate_cache
  10. from astroid.typing import SuccessfulInferenceResult, TransformFn
  11. if TYPE_CHECKING:
  12. from astroid import nodes
  # Type variable for the concrete node type a transform/predicate is registered
  # for; bounded by SuccessfulInferenceResult.
  _SuccessfulInferenceResultT = TypeVar(
      "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
  )
  # Optional filter for a transform: TransformVisitor._transform calls it with the
  # node and only runs the transform when it returns a truthy value; None means
  # "always apply".
  _Predicate = Callable[[_SuccessfulInferenceResultT], bool] | None
  # Shapes a node's child field may hold when TransformVisitor._visit walks
  # `_astroid_fields`. Spelled with Union and string forward references because
  # `nodes` is only imported under TYPE_CHECKING, and plain strings do not
  # support the `|` operator at runtime.
  # pylint: disable-next=consider-alternative-union-syntax
  _Vistables = Union[
      "nodes.NodeNG", list["nodes.NodeNG"], tuple["nodes.NodeNG", ...], str, None
  ]
  # Corresponding result shapes produced by TransformVisitor._visit_generic.
  _VisitReturns = (
      SuccessfulInferenceResult
      | list[SuccessfulInferenceResult]
      | tuple[SuccessfulInferenceResult, ...]
      | str
      | None
  )
  28. class TransformVisitor:
  29. """A visitor for handling transforms.
  30. The standard approach of using it is to call
  31. :meth:`~visit` with an *astroid* module and the class
  32. will take care of the rest, walking the tree and running the
  33. transforms for each encountered node.
  34. Based on its usage in AstroidManager.brain, it should not be reinstantiated.
  35. """
  36. def __init__(self) -> None:
  37. # The typing here is incorrect, but it's the best we can do
  38. # Refer to register_transform and unregister_transform for the correct types
  39. self.transforms: defaultdict[
  40. type[SuccessfulInferenceResult],
  41. list[
  42. tuple[
  43. TransformFn[SuccessfulInferenceResult],
  44. _Predicate[SuccessfulInferenceResult],
  45. ]
  46. ],
  47. ] = defaultdict(list)
  48. def _transform(self, node: SuccessfulInferenceResult) -> SuccessfulInferenceResult:
  49. """Call matching transforms for the given node if any and return the
  50. transformed node.
  51. """
  52. cls = node.__class__
  53. for transform_func, predicate in self.transforms[cls]:
  54. if predicate is None or predicate(node):
  55. ret = transform_func(node)
  56. # if the transformation function returns something, it's
  57. # expected to be a replacement for the node
  58. if ret is not None:
  59. _invalidate_cache()
  60. node = ret
  61. if ret.__class__ != cls:
  62. # Can no longer apply the rest of the transforms.
  63. break
  64. return node
  65. def _visit(self, node: nodes.NodeNG) -> SuccessfulInferenceResult:
  66. for name in node._astroid_fields:
  67. value = getattr(node, name)
  68. if TYPE_CHECKING:
  69. value = cast(_Vistables, value)
  70. visited = self._visit_generic(value)
  71. if visited != value:
  72. setattr(node, name, visited)
  73. return self._transform(node)
  74. @overload
  75. def _visit_generic(self, node: None) -> None: ...
  76. @overload
  77. def _visit_generic(self, node: str) -> str: ...
  78. @overload
  79. def _visit_generic(
  80. self, node: list[nodes.NodeNG]
  81. ) -> list[SuccessfulInferenceResult]: ...
  82. @overload
  83. def _visit_generic(
  84. self, node: tuple[nodes.NodeNG, ...]
  85. ) -> tuple[SuccessfulInferenceResult, ...]: ...
  86. @overload
  87. def _visit_generic(self, node: nodes.NodeNG) -> SuccessfulInferenceResult: ...
  88. def _visit_generic(self, node: _Vistables) -> _VisitReturns:
  89. if not node:
  90. return node
  91. if isinstance(node, list):
  92. return [self._visit_generic(child) for child in node]
  93. if isinstance(node, tuple):
  94. return tuple(self._visit_generic(child) for child in node)
  95. if isinstance(node, str):
  96. return node
  97. try:
  98. return self._visit(node)
  99. except RecursionError:
  100. # Returning the node untransformed is better than giving up.
  101. warnings.warn(
  102. f"Astroid was unable to transform {node}.\n"
  103. "Some functionality will be missing unless the system recursion limit is lifted.\n"
  104. "From pylint, try: --init-hook='import sys; sys.setrecursionlimit(2000)' or higher.",
  105. UserWarning,
  106. stacklevel=0,
  107. )
  108. return node
  109. def register_transform(
  110. self,
  111. node_class: type[_SuccessfulInferenceResultT],
  112. transform: TransformFn[_SuccessfulInferenceResultT],
  113. predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
  114. ) -> None:
  115. """Register `transform(node)` function to be applied on the given node.
  116. The transform will only be applied if `predicate` is None or returns true
  117. when called with the node as argument.
  118. The transform function may return a value which is then used to
  119. substitute the original node in the tree.
  120. """
  121. self.transforms[node_class].append((transform, predicate)) # type: ignore[index, arg-type]
  122. def unregister_transform(
  123. self,
  124. node_class: type[_SuccessfulInferenceResultT],
  125. transform: TransformFn[_SuccessfulInferenceResultT],
  126. predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
  127. ) -> None:
  128. """Unregister the given transform."""
  129. self.transforms[node_class].remove((transform, predicate)) # type: ignore[index, arg-type]
  130. def visit(self, node: nodes.NodeNG) -> SuccessfulInferenceResult:
  131. """Walk the given astroid *tree* and transform each encountered node.
  132. Only the nodes which have transforms registered will actually
  133. be replaced or changed.
  134. """
  135. return self._visit(node)