star_args.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """
This module is responsible for inferring *args and **kwargs for signatures.

This means for example in this case::

    def foo(a, b, c): ...

    def bar(*args):
        return foo(1, *args)

The signature here for bar should be `bar(b, c)` instead of `bar(*args)`.
  8. """
  9. from inspect import Parameter
  10. from parso import tree
  11. from jedi.inference.utils import to_list
  12. from jedi.inference.names import ParamNameWrapper
  13. from jedi.inference.helpers import is_big_annoying_library
  14. def _iter_nodes_for_param(param_name):
  15. from parso.python.tree import search_ancestor
  16. from jedi.inference.arguments import TreeArguments
  17. execution_context = param_name.parent_context
  18. # Walk up the parso tree to get the FunctionNode we want. We use the parso
  19. # tree rather than going via the execution context so that we're agnostic of
  20. # the specific scope we're evaluating within (i.e: module or function,
  21. # etc.).
  22. function_node = tree.search_ancestor(param_name.tree_name, 'funcdef', 'lambdef')
  23. module_node = function_node.get_root_node()
  24. start = function_node.children[-1].start_pos
  25. end = function_node.children[-1].end_pos
  26. for name in module_node.get_used_names().get(param_name.string_name):
  27. if start <= name.start_pos < end:
  28. # Is used in the function
  29. argument = name.parent
  30. if argument.type == 'argument' \
  31. and argument.children[0] == '*' * param_name.star_count:
  32. trailer = search_ancestor(argument, 'trailer')
  33. if trailer is not None: # Make sure we're in a function
  34. context = execution_context.create_context(trailer)
  35. if _goes_to_param_name(param_name, context, name):
  36. values = _to_callables(context, trailer)
  37. args = TreeArguments.create_cached(
  38. execution_context.inference_state,
  39. context=context,
  40. argument_node=trailer.children[1],
  41. trailer=trailer,
  42. )
  43. for c in values:
  44. yield c, args
  45. def _goes_to_param_name(param_name, context, potential_name):
  46. if potential_name.type != 'name':
  47. return False
  48. from jedi.inference.names import TreeNameDefinition
  49. found = TreeNameDefinition(context, potential_name).goto()
  50. return any(param_name.parent_context == p.parent_context
  51. and param_name.start_pos == p.start_pos
  52. for p in found)
  53. def _to_callables(context, trailer):
  54. from jedi.inference.syntax_tree import infer_trailer
  55. atom_expr = trailer.parent
  56. index = atom_expr.children[0] == 'await'
  57. # Infer atom first
  58. values = context.infer_node(atom_expr.children[index])
  59. for trailer2 in atom_expr.children[index + 1:]:
  60. if trailer == trailer2:
  61. break
  62. values = infer_trailer(context, values, trailer2)
  63. return values
  64. def _remove_given_params(arguments, param_names):
  65. count = 0
  66. used_keys = set()
  67. for key, _ in arguments.unpack():
  68. if key is None:
  69. count += 1
  70. else:
  71. used_keys.add(key)
  72. for p in param_names:
  73. if count and p.maybe_positional_argument():
  74. count -= 1
  75. continue
  76. if p.string_name in used_keys and p.maybe_keyword_argument():
  77. continue
  78. yield p
@to_list
def process_params(param_names, star_count=3):  # default means both * and **
    """
    Replace ``*args`` / ``**kwargs`` entries of ``param_names`` by the
    parameters of the callables they are forwarded to.

    ``star_count`` is used as a bit mask: bit ``1`` enables the handling of
    ``*args`` (and positional-only params), bit ``2`` enables the handling
    of ``**kwargs`` (and keyword-only params). The default ``3`` enables
    both.

    The params are yielded in this order:

    1. positional-only and ordinary params of ``param_names`` (ordinary
       ones are re-labelled as positional-only when ``star_count == 1`` and
       collected as keyword-only when ``star_count == 2``),
    2. the longest list of positional params found in any signature that
       ``*args`` is forwarded to,
    3. a ``*args`` param: the original one if no signature was found,
       otherwise the first one found in the forwarded-to signatures,
    4. keyword-only params whose names were not used already,
    5. a ``**kwargs`` param, chosen like the ``*args`` one.

    The generator is wrapped by ``to_list``; presumably callers therefore
    receive a list (confirm against ``jedi.inference.utils.to_list``).
    """
    if param_names:
        if is_big_annoying_library(param_names[0].parent_context):
            # At first this feature can look innocent, but it does a lot of
            # type inference in some cases, so we just ditch it.
            yield from param_names
            return

    # Names of ordinary params that were already yielded; used further down
    # to avoid yielding a keyword-only param with the same name again.
    used_names = set()
    # (callable value, arguments) pairs that *args / **kwargs are passed to.
    arg_callables = []
    kwarg_callables = []

    kw_only_names = []
    kwarg_names = []
    arg_names = []
    original_arg_name = None
    original_kwarg_name = None
    for p in param_names:
        kind = p.get_kind()
        if kind == Parameter.VAR_POSITIONAL:
            if star_count & 1:
                # Kept as a generator here, it is consumed once below.
                arg_callables = _iter_nodes_for_param(p)
                original_arg_name = p
        elif p.get_kind() == Parameter.VAR_KEYWORD:
            if star_count & 2:
                # Needs to be a list: it is searched and mutated below.
                kwarg_callables = list(_iter_nodes_for_param(p))
                original_kwarg_name = p
        elif kind == Parameter.KEYWORD_ONLY:
            if star_count & 2:
                kw_only_names.append(p)
        elif kind == Parameter.POSITIONAL_ONLY:
            if star_count & 1:
                yield p
        else:
            if star_count == 1:
                # Only *args is being resolved, so the param is reported as
                # positional-only.
                yield ParamNameFixedKind(p, Parameter.POSITIONAL_ONLY)
            elif star_count == 2:
                # Only **kwargs is being resolved, so the param is reported
                # as keyword-only.
                kw_only_names.append(ParamNameFixedKind(p, Parameter.KEYWORD_ONLY))
            else:
                used_names.add(p.string_name)
                yield p

    # First process *args
    longest_param_names = ()
    found_arg_signature = False
    found_kwarg_signature = False
    for func_and_argument in arg_callables:
        func, arguments = func_and_argument
        new_star_count = star_count
        if func_and_argument in kwarg_callables:
            # The same call also receives **kwargs; it is handled here, so
            # drop it from the **kwargs pass and keep the star count.
            kwarg_callables.remove(func_and_argument)
        else:
            new_star_count = 1

        for signature in func.get_signatures():
            found_arg_signature = True
            if new_star_count == 3:
                found_kwarg_signature = True
            args_for_this_func = []
            for p in process_params(
                    list(_remove_given_params(
                        arguments,
                        signature.get_param_names(resolve_stars=False)
                    )), new_star_count):
                if p.get_kind() == Parameter.VAR_KEYWORD:
                    kwarg_names.append(p)
                elif p.get_kind() == Parameter.VAR_POSITIONAL:
                    arg_names.append(p)
                elif p.get_kind() == Parameter.KEYWORD_ONLY:
                    kw_only_names.append(p)
                else:
                    args_for_this_func.append(p)
            # Of all candidate signatures only the one with the most
            # remaining positional params is reported.
            if len(args_for_this_func) > len(longest_param_names):
                longest_param_names = args_for_this_func

    for p in longest_param_names:
        if star_count == 1 and p.get_kind() != Parameter.VAR_POSITIONAL:
            yield ParamNameFixedKind(p, Parameter.POSITIONAL_ONLY)
        else:
            if p.get_kind() == Parameter.POSITIONAL_OR_KEYWORD:
                used_names.add(p.string_name)
            yield p

    # Keep the original *args if nothing could be resolved for it, otherwise
    # use the first *args found in the forwarded-to signatures (if any).
    if not found_arg_signature and original_arg_name is not None:
        yield original_arg_name
    elif arg_names:
        yield arg_names[0]

    # Then process **kwargs
    for func, arguments in kwarg_callables:
        for signature in func.get_signatures():
            found_kwarg_signature = True
            for p in process_params(
                    list(_remove_given_params(
                        arguments,
                        signature.get_param_names(resolve_stars=False)
                    )), star_count=2):
                if p.get_kind() == Parameter.VAR_KEYWORD:
                    kwarg_names.append(p)
                elif p.get_kind() == Parameter.KEYWORD_ONLY:
                    kw_only_names.append(p)

    for p in kw_only_names:
        # Skip names that were already yielded before.
        if p.string_name in used_names:
            continue
        yield p
        used_names.add(p.string_name)

    # Same fallback logic as for *args above.
    if not found_kwarg_signature and original_kwarg_name is not None:
        yield original_kwarg_name
    elif kwarg_names:
        yield kwarg_names[0]
  183. class ParamNameFixedKind(ParamNameWrapper):
  184. def __init__(self, param_name, new_kind):
  185. super().__init__(param_name)
  186. self._new_kind = new_kind
  187. def get_kind(self):
  188. return self._new_kind