nested_min_max.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Check for use of nested min/max functions."""
  5. from __future__ import annotations
  6. import copy
  7. from typing import TYPE_CHECKING
  8. from astroid import nodes, objects
  9. from astroid.const import Context
  10. from pylint.checkers import BaseChecker
  11. from pylint.checkers.utils import only_required_for_messages, safe_infer
  12. from pylint.interfaces import INFERENCE
  13. if TYPE_CHECKING:
  14. from pylint.lint import PyLinter
  15. DICT_TYPES = (
  16. objects.DictValues,
  17. objects.DictKeys,
  18. objects.DictItems,
  19. nodes.Dict,
  20. )
  21. class NestedMinMaxChecker(BaseChecker):
  22. """Multiple nested min/max calls on the same line will raise multiple messages.
  23. This behaviour is intended as it would slow down the checker to check
  24. for nested call with minimal benefits.
  25. """
  26. FUNC_NAMES = ("builtins.min", "builtins.max")
  27. name = "nested_min_max"
  28. msgs = {
  29. "W3301": (
  30. "Do not use nested call of '%s'; it's possible to do '%s' instead",
  31. "nested-min-max",
  32. "Nested calls ``min(1, min(2, 3))`` can be rewritten as ``min(1, 2, 3)``.",
  33. )
  34. }
  35. @classmethod
  36. def maybe_get_inferred_min_max_call(
  37. cls, node: nodes.Call
  38. ) -> nodes.FunctionDef | None:
  39. inferred = safe_infer(node.func)
  40. if (
  41. isinstance(inferred, nodes.FunctionDef)
  42. and inferred.qname() in cls.FUNC_NAMES
  43. ):
  44. return inferred
  45. return None
  46. @classmethod
  47. def get_redundant_calls(
  48. cls, node: nodes.Call, inferred_call: nodes.FunctionDef
  49. ) -> list[nodes.Call]:
  50. return [
  51. arg
  52. for arg in node.args
  53. if (
  54. isinstance(arg, nodes.Call)
  55. and (inferred := cls.maybe_get_inferred_min_max_call(arg))
  56. and inferred.qname == inferred_call.qname
  57. # Nesting is useful for finding the maximum in a matrix.
  58. # Allow: max(max([[1, 2, 3], [4, 5, 6]]))
  59. # Meaning, redundant call only if parent max call has more than 1 arg.
  60. and len(arg.parent.args) > 1
  61. )
  62. ]
  63. @only_required_for_messages("nested-min-max")
  64. def visit_call(self, node: nodes.Call) -> None:
  65. inferred = self.maybe_get_inferred_min_max_call(node)
  66. if inferred is None:
  67. return
  68. redundant_calls = self.get_redundant_calls(node, inferred)
  69. if not redundant_calls:
  70. return
  71. fixed_node = copy.copy(node)
  72. while len(redundant_calls) > 0:
  73. for i, arg in enumerate(fixed_node.args):
  74. # Exclude any calls with generator expressions as there is no
  75. # clear better suggestion for them.
  76. if isinstance(arg, nodes.Call) and any(
  77. isinstance(a, nodes.GeneratorExp) for a in arg.args
  78. ):
  79. return
  80. if arg in redundant_calls:
  81. fixed_node.args = (
  82. fixed_node.args[:i] + arg.args + fixed_node.args[i + 1 :]
  83. )
  84. break
  85. redundant_calls = self.get_redundant_calls(fixed_node, inferred)
  86. for idx, arg in enumerate(fixed_node.args):
  87. if not isinstance(arg, nodes.Const):
  88. if self._is_splattable_expression(arg):
  89. splat_node = nodes.Starred(
  90. ctx=Context.Load,
  91. lineno=arg.lineno,
  92. col_offset=0,
  93. parent=nodes.NodeNG(
  94. lineno=None,
  95. col_offset=None,
  96. end_lineno=None,
  97. end_col_offset=None,
  98. parent=None,
  99. ),
  100. end_lineno=0,
  101. end_col_offset=0,
  102. )
  103. splat_node.value = arg
  104. fixed_node.args = [
  105. *fixed_node.args[:idx],
  106. splat_node,
  107. *fixed_node.args[idx + 1 : idx],
  108. ]
  109. func_name = (
  110. node.func.attrname
  111. if isinstance(node.func, nodes.Attribute)
  112. else node.func.name
  113. )
  114. self.add_message(
  115. "nested-min-max",
  116. node=node,
  117. args=(func_name, fixed_node.as_string()),
  118. confidence=INFERENCE,
  119. )
  120. def _is_splattable_expression(self, arg: nodes.NodeNG) -> bool:
  121. """Returns true if expression under min/max could be converted to splat
  122. expression.
  123. """
  124. # Support sequence addition (operator __add__)
  125. if isinstance(arg, nodes.BinOp) and arg.op == "+":
  126. return self._is_splattable_expression(
  127. arg.left
  128. ) and self._is_splattable_expression(arg.right)
  129. # Support dict merge (operator __or__)
  130. if isinstance(arg, nodes.BinOp) and arg.op == "|":
  131. return self._is_splattable_expression(
  132. arg.left
  133. ) and self._is_splattable_expression(arg.right)
  134. inferred = safe_infer(arg)
  135. if inferred and inferred.pytype() in {"builtins.list", "builtins.tuple"}:
  136. return True
  137. if isinstance(
  138. inferred or arg,
  139. (
  140. nodes.List,
  141. nodes.Tuple,
  142. nodes.Set,
  143. nodes.ListComp,
  144. nodes.DictComp,
  145. *DICT_TYPES,
  146. ),
  147. ):
  148. return True
  149. return False
  150. def register(linter: PyLinter) -> None:
  151. linter.register_checker(NestedMinMaxChecker(linter))