| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
- # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
- # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
- """Check for use of nested min/max functions."""
- from __future__ import annotations
- import copy
- from typing import TYPE_CHECKING
- from astroid import nodes, objects
- from astroid.const import Context
- from pylint.checkers import BaseChecker
- from pylint.checkers.utils import only_required_for_messages, safe_infer
- from pylint.interfaces import INFERENCE
- if TYPE_CHECKING:
- from pylint.lint import PyLinter
# Node/proxy classes treated as dictionary-like when deciding whether a
# min/max argument can be suggested in unpacked (``*arg``) form.
DICT_TYPES = (objects.DictValues, objects.DictKeys, objects.DictItems, nodes.Dict)
class NestedMinMaxChecker(BaseChecker):
    """Flag nested ``min``/``max`` calls that can be flattened into one call.

    Multiple nested min/max calls on the same line will raise multiple messages.
    This behaviour is intended as it would slow down the checker to check
    for nested call with minimal benefits.
    """

    # Qualified names of the callables this checker is interested in.
    FUNC_NAMES = ("builtins.min", "builtins.max")

    name = "nested_min_max"
    msgs = {
        "W3301": (
            "Do not use nested call of '%s'; it's possible to do '%s' instead",
            "nested-min-max",
            "Nested calls ``min(1, min(2, 3))`` can be rewritten as ``min(1, 2, 3)``.",
        )
    }

    @classmethod
    def maybe_get_inferred_min_max_call(
        cls, node: nodes.Call
    ) -> nodes.FunctionDef | None:
        """Return the inferred callee of ``node`` if it is builtin min/max.

        :param node: The call whose ``func`` expression is inferred.
        :returns: The inferred ``FunctionDef`` when its qualified name is one
            of ``FUNC_NAMES``; ``None`` otherwise.
        """
        inferred = safe_infer(node.func)
        if (
            isinstance(inferred, nodes.FunctionDef)
            and inferred.qname() in cls.FUNC_NAMES
        ):
            return inferred
        return None

    @classmethod
    def get_redundant_calls(
        cls, node: nodes.Call, inferred_call: nodes.FunctionDef
    ) -> list[nodes.Call]:
        """Collect the arguments of ``node`` that are redundant nested calls.

        :param node: The outer min/max call whose arguments are inspected.
        :param inferred_call: The function ``node`` itself was inferred to call.
        :returns: The arguments that are calls to the same function as
            ``inferred_call`` and whose parent call has more than one argument.
        """
        return [
            arg
            for arg in node.args
            if (
                isinstance(arg, nodes.Call)
                and (inferred := cls.maybe_get_inferred_min_max_call(arg))
                # ``qname`` is a method: it has to be called so that the
                # qualified names are compared, not the bound method objects.
                and inferred.qname() == inferred_call.qname()
                # Nesting is useful for finding the maximum in a matrix.
                # Allow: max(max([[1, 2, 3], [4, 5, 6]]))
                # Meaning, redundant call only if parent max call has more than 1 arg.
                and len(arg.parent.args) > 1
            )
        ]

    @only_required_for_messages("nested-min-max")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``nested-min-max`` for a min/max call with nested same-kind calls.

        The suggested replacement is built on a copy of ``node``: nested calls
        are inlined one at a time, then every non-constant argument that
        ``_is_splattable_expression`` accepts is wrapped in a ``*`` unpacking.
        No message is emitted when a call argument contains a generator
        expression.

        :param node: The call node being visited.
        """
        inferred = self.maybe_get_inferred_min_max_call(node)
        if inferred is None:
            return

        redundant_calls = self.get_redundant_calls(node, inferred)
        if not redundant_calls:
            return

        # Shallow copy: ``args`` is only ever rebound on the copy (never
        # mutated in place), so the argument list of ``node`` is left alone.
        fixed_node = copy.copy(node)
        while redundant_calls:
            for i, arg in enumerate(fixed_node.args):
                # Exclude any calls with generator expressions as there is no
                # clear better suggestion for them.
                if isinstance(arg, nodes.Call) and any(
                    isinstance(a, nodes.GeneratorExp) for a in arg.args
                ):
                    return

                if arg in redundant_calls:
                    # Inline the nested call's arguments in its place, then
                    # rescan because the inlined arguments may be nested too.
                    fixed_node.args = (
                        fixed_node.args[:i] + arg.args + fixed_node.args[i + 1 :]
                    )
                    break

            redundant_calls = self.get_redundant_calls(fixed_node, inferred)

        # Replace splattable arguments one-for-one. Building a fresh list keeps
        # every argument that follows a splatted one in the suggestion.
        fixed_node.args = [
            (
                self._wrap_in_starred(arg)
                if not isinstance(arg, nodes.Const)
                and self._is_splattable_expression(arg)
                else arg
            )
            for arg in fixed_node.args
        ]

        if isinstance(node.func, nodes.Attribute):
            func_name = node.func.attrname
        else:
            # A callee expression that is neither a name nor an attribute may
            # not expose ``.name``; fall back to the inferred function's name.
            func_name = getattr(node.func, "name", inferred.name)

        self.add_message(
            "nested-min-max",
            node=node,
            args=(func_name, fixed_node.as_string()),
            confidence=INFERENCE,
        )

    @staticmethod
    def _wrap_in_starred(arg: nodes.NodeNG) -> nodes.Starred:
        """Return a ``Starred`` node whose value is ``arg``.

        The node is only used to render the suggestion text, so it is attached
        to a placeholder parent and given placeholder positions.
        """
        splat_node = nodes.Starred(
            ctx=Context.Load,
            lineno=arg.lineno,
            col_offset=0,
            parent=nodes.NodeNG(
                lineno=None,
                col_offset=None,
                end_lineno=None,
                end_col_offset=None,
                parent=None,
            ),
            end_lineno=0,
            end_col_offset=0,
        )
        splat_node.value = arg
        return splat_node

    def _is_splattable_expression(self, arg: nodes.NodeNG) -> bool:
        """Returns true if expression under min/max could be converted to splat
        expression.
        """
        # Support sequence addition (operator __add__) and dict merge
        # (operator __or__): both operands have to be splattable themselves.
        if isinstance(arg, nodes.BinOp) and arg.op in {"+", "|"}:
            return self._is_splattable_expression(
                arg.left
            ) and self._is_splattable_expression(arg.right)

        inferred = safe_infer(arg)
        if inferred and inferred.pytype() in {"builtins.list", "builtins.tuple"}:
            return True

        # Fall back to the syntactic node when inference gave nothing usable.
        return isinstance(
            inferred or arg,
            (
                nodes.List,
                nodes.Tuple,
                nodes.Set,
                nodes.ListComp,
                nodes.DictComp,
                *DICT_TYPES,
            ),
        )
def register(linter: PyLinter) -> None:
    """Instantiate the nested min/max checker and register it on ``linter``."""
    checker = NestedMinMaxChecker(linter)
    linter.register_checker(checker)
|