constraint.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. """Classes representing different types of constraints on inference values."""
  5. from __future__ import annotations
  6. import sys
  7. from abc import ABC, abstractmethod
  8. from collections.abc import Iterator
  9. from typing import TYPE_CHECKING
  10. from astroid import nodes, util
  11. from astroid.typing import InferenceResult
  12. if sys.version_info >= (3, 11):
  13. from typing import Self
  14. else:
  15. from typing_extensions import Self
  16. if TYPE_CHECKING:
  17. from astroid import bases
# Node types accepted as the constrained expression: this alias annotates the
# ``node`` parameter of ``Constraint.match`` / ``_match_constraint`` and the
# ``expr`` parameter of ``get_constraints``.
_NameNodes = nodes.AssignAttr | nodes.Attribute | nodes.AssignName | nodes.Name
  19. class Constraint(ABC):
  20. """Represents a single constraint on a variable."""
  21. def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
  22. self.node = node
  23. """The node that this constraint applies to."""
  24. self.negate = negate
  25. """True if this constraint is negated. E.g., "is not" instead of "is"."""
  26. @classmethod
  27. @abstractmethod
  28. def match(
  29. cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
  30. ) -> Self | None:
  31. """Return a new constraint for node matched from expr, if expr matches
  32. the constraint pattern.
  33. If negate is True, negate the constraint.
  34. """
  35. @abstractmethod
  36. def satisfied_by(self, inferred: InferenceResult) -> bool:
  37. """Return True if this constraint is satisfied by the given inferred value."""
  38. class NoneConstraint(Constraint):
  39. """Represents an "is None" or "is not None" constraint."""
  40. CONST_NONE: nodes.Const = nodes.Const(None)
  41. @classmethod
  42. def match(
  43. cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
  44. ) -> Self | None:
  45. """Return a new constraint for node matched from expr, if expr matches
  46. the constraint pattern.
  47. Negate the constraint based on the value of negate.
  48. """
  49. if isinstance(expr, nodes.Compare) and len(expr.ops) == 1:
  50. left = expr.left
  51. op, right = expr.ops[0]
  52. if op in {"is", "is not"} and (
  53. _matches(left, node) and _matches(right, cls.CONST_NONE)
  54. ):
  55. negate = (op == "is" and negate) or (op == "is not" and not negate)
  56. return cls(node=node, negate=negate)
  57. return None
  58. def satisfied_by(self, inferred: InferenceResult) -> bool:
  59. """Return True if this constraint is satisfied by the given inferred value."""
  60. # Assume true if uninferable
  61. if isinstance(inferred, util.UninferableBase):
  62. return True
  63. # Return the XOR of self.negate and matches(inferred, self.CONST_NONE)
  64. return self.negate ^ _matches(inferred, self.CONST_NONE)
  65. class BooleanConstraint(Constraint):
  66. """Represents an "x" or "not x" constraint."""
  67. @classmethod
  68. def match(
  69. cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
  70. ) -> Self | None:
  71. """Return a new constraint for node if expr matches one of these patterns:
  72. - direct match (expr == node): use given negate value
  73. - negated match (expr == `not node`): flip negate value
  74. Return None if no pattern matches.
  75. """
  76. if _matches(expr, node):
  77. return cls(node=node, negate=negate)
  78. if (
  79. isinstance(expr, nodes.UnaryOp)
  80. and expr.op == "not"
  81. and _matches(expr.operand, node)
  82. ):
  83. return cls(node=node, negate=not negate)
  84. return None
  85. def satisfied_by(self, inferred: InferenceResult) -> bool:
  86. """Return True for uninferable results, or depending on negate flag:
  87. - negate=False: satisfied if boolean value is True
  88. - negate=True: satisfied if boolean value is False
  89. """
  90. inferred_booleaness = inferred.bool_value()
  91. if isinstance(inferred, util.UninferableBase) or isinstance(
  92. inferred_booleaness, util.UninferableBase
  93. ):
  94. return True
  95. return self.negate ^ inferred_booleaness
  96. def get_constraints(
  97. expr: _NameNodes, frame: nodes.LocalsDictNodeNG
  98. ) -> dict[nodes.If | nodes.IfExp, set[Constraint]]:
  99. """Returns the constraints for the given expression.
  100. The returned dictionary maps the node where the constraint was generated to the
  101. corresponding constraint(s).
  102. Constraints are computed statically by analysing the code surrounding expr.
  103. Currently this only supports constraints generated from if conditions.
  104. """
  105. current_node: nodes.NodeNG | None = expr
  106. constraints_mapping: dict[nodes.If | nodes.IfExp, set[Constraint]] = {}
  107. while current_node is not None and current_node is not frame:
  108. parent = current_node.parent
  109. if isinstance(parent, (nodes.If, nodes.IfExp)):
  110. branch, _ = parent.locate_child(current_node)
  111. constraints: set[Constraint] | None = None
  112. if branch == "body":
  113. constraints = set(_match_constraint(expr, parent.test))
  114. elif branch == "orelse":
  115. constraints = set(_match_constraint(expr, parent.test, invert=True))
  116. if constraints:
  117. constraints_mapping[parent] = constraints
  118. current_node = parent
  119. return constraints_mapping
  120. ALL_CONSTRAINT_CLASSES = frozenset(
  121. (
  122. NoneConstraint,
  123. BooleanConstraint,
  124. )
  125. )
  126. """All supported constraint types."""
  127. def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
  128. """Returns True if the two nodes match."""
  129. if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
  130. return node1.name == node2.name
  131. if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
  132. return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr)
  133. if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
  134. return node1.value == node2.value
  135. return False
  136. def _match_constraint(
  137. node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
  138. ) -> Iterator[Constraint]:
  139. """Yields all constraint patterns for node that match."""
  140. for constraint_cls in ALL_CONSTRAINT_CLASSES:
  141. constraint = constraint_cls.match(node, expr, invert)
  142. if constraint:
  143. yield constraint