extract.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. from textwrap import dedent
  2. from parso import split_lines
  3. from jedi import debug
  4. from jedi.api.exceptions import RefactoringError
  5. from jedi.api.refactoring import Refactoring, EXPRESSION_PARTS
  6. from jedi.common import indent_block
  7. from jedi.parser_utils import function_is_classmethod, function_is_staticmethod
  8. _DEFINITION_SCOPES = ('suite', 'file_input')
  9. _VARIABLE_EXCTRACTABLE = EXPRESSION_PARTS + \
  10. ('atom testlist_star_expr testlist test lambdef lambdef_nocond '
  11. 'keyword name number string fstring').split()
  12. def extract_variable(inference_state, path, module_node, name, pos, until_pos):
  13. nodes = _find_nodes(module_node, pos, until_pos)
  14. debug.dbg('Extracting nodes: %s', nodes)
  15. is_expression, message = _is_expression_with_error(nodes)
  16. if not is_expression:
  17. raise RefactoringError(message)
  18. generated_code = name + ' = ' + _expression_nodes_to_string(nodes)
  19. file_to_node_changes = {path: _replace(nodes, name, generated_code, pos)}
  20. return Refactoring(inference_state, file_to_node_changes)
  21. def _is_expression_with_error(nodes):
  22. """
  23. Returns a tuple (is_expression, error_string).
  24. """
  25. if any(node.type == 'name' and node.is_definition() for node in nodes):
  26. return False, 'Cannot extract a name that defines something'
  27. if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
  28. return False, 'Cannot extract a "%s"' % nodes[0].type
  29. return True, ''
  30. def _find_nodes(module_node, pos, until_pos):
  31. """
  32. Looks up a module and tries to find the appropriate amount of nodes that
  33. are in there.
  34. """
  35. start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)
  36. if until_pos is None:
  37. if start_node.type == 'operator':
  38. next_leaf = start_node.get_next_leaf()
  39. if next_leaf is not None and next_leaf.start_pos == pos:
  40. start_node = next_leaf
  41. if _is_not_extractable_syntax(start_node):
  42. start_node = start_node.parent
  43. if start_node.parent.type == 'trailer':
  44. start_node = start_node.parent.parent
  45. while start_node.parent.type in EXPRESSION_PARTS:
  46. start_node = start_node.parent
  47. nodes = [start_node]
  48. else:
  49. # Get the next leaf if we are at the end of a leaf
  50. if start_node.end_pos == pos:
  51. next_leaf = start_node.get_next_leaf()
  52. if next_leaf is not None:
  53. start_node = next_leaf
  54. # Some syntax is not exactable, just use its parent
  55. if _is_not_extractable_syntax(start_node):
  56. start_node = start_node.parent
  57. # Find the end
  58. end_leaf = module_node.get_leaf_for_position(until_pos, include_prefixes=True)
  59. if end_leaf.start_pos > until_pos:
  60. end_leaf = end_leaf.get_previous_leaf()
  61. if end_leaf is None:
  62. raise RefactoringError('Cannot extract anything from that')
  63. parent_node = start_node
  64. while parent_node.end_pos < end_leaf.end_pos:
  65. parent_node = parent_node.parent
  66. nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos)
  67. # If the user marks just a return statement, we return the expression
  68. # instead of the whole statement, because the user obviously wants to
  69. # extract that part.
  70. if len(nodes) == 1 and start_node.type in ('return_stmt', 'yield_expr'):
  71. return [nodes[0].children[1]]
  72. return nodes
  73. def _replace(nodes, expression_replacement, extracted, pos,
  74. insert_before_leaf=None, remaining_prefix=None):
  75. # Now try to replace the nodes found with a variable and move the code
  76. # before the current statement.
  77. definition = _get_parent_definition(nodes[0])
  78. if insert_before_leaf is None:
  79. insert_before_leaf = definition.get_first_leaf()
  80. first_node_leaf = nodes[0].get_first_leaf()
  81. lines = split_lines(insert_before_leaf.prefix, keepends=True)
  82. if first_node_leaf is insert_before_leaf:
  83. if remaining_prefix is not None:
  84. # The remaining prefix has already been calculated.
  85. lines[:-1] = remaining_prefix
  86. lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n']
  87. extracted_prefix = ''.join(lines)
  88. replacement_dct = {}
  89. if first_node_leaf is insert_before_leaf:
  90. replacement_dct[nodes[0]] = extracted_prefix + expression_replacement
  91. else:
  92. if remaining_prefix is None:
  93. p = first_node_leaf.prefix
  94. else:
  95. p = remaining_prefix + _get_indentation(nodes[0])
  96. replacement_dct[nodes[0]] = p + expression_replacement
  97. replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value
  98. for node in nodes[1:]:
  99. replacement_dct[node] = ''
  100. return replacement_dct
  101. def _expression_nodes_to_string(nodes):
  102. return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes))
  103. def _suite_nodes_to_string(nodes, pos):
  104. n = nodes[0]
  105. prefix, part_of_code = _split_prefix_at(n.get_first_leaf(), pos[0] - 1)
  106. code = part_of_code + n.get_code(include_prefix=False) \
  107. + ''.join(n.get_code() for n in nodes[1:])
  108. return prefix, code
  109. def _split_prefix_at(leaf, until_line):
  110. """
  111. Returns a tuple of the leaf's prefix, split at the until_line
  112. position.
  113. """
  114. # second means the second returned part
  115. second_line_count = leaf.start_pos[0] - until_line
  116. lines = split_lines(leaf.prefix, keepends=True)
  117. return ''.join(lines[:-second_line_count]), ''.join(lines[-second_line_count:])
  118. def _get_indentation(node):
  119. return split_lines(node.get_first_leaf().prefix)[-1]
  120. def _get_parent_definition(node):
  121. """
  122. Returns the statement where a node is defined.
  123. """
  124. while node is not None:
  125. if node.parent.type in _DEFINITION_SCOPES:
  126. return node
  127. node = node.parent
  128. raise NotImplementedError('We should never even get here')
  129. def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
  130. """
  131. This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
  132. though it is not part of the expression.
  133. """
  134. typ = parent_node.type
  135. is_suite_part = typ in ('suite', 'file_input')
  136. if typ in EXPRESSION_PARTS or is_suite_part:
  137. nodes = parent_node.children
  138. for i, n in enumerate(nodes):
  139. if n.end_pos > pos:
  140. start_index = i
  141. if n.type == 'operator':
  142. start_index -= 1
  143. break
  144. for i, n in reversed(list(enumerate(nodes))):
  145. if n.start_pos < until_pos:
  146. end_index = i
  147. if n.type == 'operator':
  148. end_index += 1
  149. # Something like `not foo or bar` should not be cut after not
  150. for n2 in nodes[i:]:
  151. if _is_not_extractable_syntax(n2):
  152. end_index += 1
  153. else:
  154. break
  155. break
  156. nodes = nodes[start_index:end_index + 1]
  157. if not is_suite_part:
  158. nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
  159. nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
  160. return nodes
  161. return [parent_node]
  162. def _is_not_extractable_syntax(node):
  163. return node.type == 'operator' \
  164. or node.type == 'keyword' and node.value not in ('None', 'True', 'False')
  165. def extract_function(inference_state, path, module_context, name, pos, until_pos):
  166. nodes = _find_nodes(module_context.tree_node, pos, until_pos)
  167. assert len(nodes)
  168. is_expression, _ = _is_expression_with_error(nodes)
  169. context = module_context.create_context(nodes[0])
  170. is_bound_method = context.is_bound_method()
  171. params, return_variables = list(_find_inputs_and_outputs(module_context, context, nodes))
  172. # Find variables
  173. # Is a class method / method
  174. if context.is_module():
  175. insert_before_leaf = None # Leaf will be determined later
  176. else:
  177. node = _get_code_insertion_node(context.tree_node, is_bound_method)
  178. insert_before_leaf = node.get_first_leaf()
  179. if is_expression:
  180. code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
  181. remaining_prefix = None
  182. has_ending_return_stmt = False
  183. else:
  184. has_ending_return_stmt = _is_node_ending_return_stmt(nodes[-1])
  185. if not has_ending_return_stmt:
  186. # Find the actually used variables (of the defined ones). If none are
  187. # used (e.g. if the range covers the whole function), return the last
  188. # defined variable.
  189. return_variables = list(_find_needed_output_variables(
  190. context,
  191. nodes[0].parent,
  192. nodes[-1].end_pos,
  193. return_variables
  194. )) or [return_variables[-1]] if return_variables else []
  195. remaining_prefix, code_block = _suite_nodes_to_string(nodes, pos)
  196. after_leaf = nodes[-1].get_next_leaf()
  197. first, second = _split_prefix_at(after_leaf, until_pos[0])
  198. code_block += first
  199. code_block = dedent(code_block)
  200. if not has_ending_return_stmt:
  201. output_var_str = ', '.join(return_variables)
  202. code_block += 'return ' + output_var_str + '\n'
  203. # Check if we have to raise RefactoringError
  204. _check_for_non_extractables(nodes[:-1] if has_ending_return_stmt else nodes)
  205. decorator = ''
  206. self_param = None
  207. if is_bound_method:
  208. if not function_is_staticmethod(context.tree_node):
  209. function_param_names = context.get_value().get_param_names()
  210. if len(function_param_names):
  211. self_param = function_param_names[0].string_name
  212. params = [p for p in params if p != self_param]
  213. if function_is_classmethod(context.tree_node):
  214. decorator = '@classmethod\n'
  215. else:
  216. code_block += '\n'
  217. function_code = '%sdef %s(%s):\n%s' % (
  218. decorator,
  219. name,
  220. ', '.join(params if self_param is None else [self_param] + params),
  221. indent_block(code_block)
  222. )
  223. function_call = '%s(%s)' % (
  224. ('' if self_param is None else self_param + '.') + name,
  225. ', '.join(params)
  226. )
  227. if is_expression:
  228. replacement = function_call
  229. else:
  230. if has_ending_return_stmt:
  231. replacement = 'return ' + function_call + '\n'
  232. else:
  233. replacement = output_var_str + ' = ' + function_call + '\n'
  234. replacement_dct = _replace(nodes, replacement, function_code, pos,
  235. insert_before_leaf, remaining_prefix)
  236. if not is_expression:
  237. replacement_dct[after_leaf] = second + after_leaf.value
  238. file_to_node_changes = {path: replacement_dct}
  239. return Refactoring(inference_state, file_to_node_changes)
  240. def _check_for_non_extractables(nodes):
  241. for n in nodes:
  242. try:
  243. children = n.children
  244. except AttributeError:
  245. if n.value == 'return':
  246. raise RefactoringError(
  247. 'Can only extract return statements if they are at the end.')
  248. if n.value == 'yield':
  249. raise RefactoringError('Cannot extract yield statements.')
  250. else:
  251. _check_for_non_extractables(children)
  252. def _is_name_input(module_context, names, first, last):
  253. for name in names:
  254. if name.api_type == 'param' or not name.parent_context.is_module():
  255. if name.get_root_context() is not module_context:
  256. return True
  257. if name.start_pos is None or not (first <= name.start_pos < last):
  258. return True
  259. return False
  260. def _find_inputs_and_outputs(module_context, context, nodes):
  261. first = nodes[0].start_pos
  262. last = nodes[-1].end_pos
  263. inputs = []
  264. outputs = []
  265. for name in _find_non_global_names(nodes):
  266. if name.is_definition():
  267. if name not in outputs:
  268. outputs.append(name.value)
  269. else:
  270. if name.value not in inputs:
  271. name_definitions = context.goto(name, name.start_pos)
  272. if not name_definitions \
  273. or _is_name_input(module_context, name_definitions, first, last):
  274. inputs.append(name.value)
  275. # Check if outputs are really needed:
  276. return inputs, outputs
  277. def _find_non_global_names(nodes):
  278. for node in nodes:
  279. try:
  280. children = node.children
  281. except AttributeError:
  282. if node.type == 'name':
  283. yield node
  284. else:
  285. # We only want to check foo in foo.bar
  286. if node.type == 'trailer' and node.children[0] == '.':
  287. continue
  288. yield from _find_non_global_names(children)
  289. def _get_code_insertion_node(node, is_bound_method):
  290. if not is_bound_method or function_is_staticmethod(node):
  291. while node.parent.type != 'file_input':
  292. node = node.parent
  293. while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'):
  294. node = node.parent
  295. return node
  296. def _find_needed_output_variables(context, search_node, at_least_pos, return_variables):
  297. """
  298. Searches everything after at_least_pos in a node and checks if any of the
  299. return_variables are used in there and returns those.
  300. """
  301. for node in search_node.children:
  302. if node.start_pos < at_least_pos:
  303. continue
  304. return_variables = set(return_variables)
  305. for name in _find_non_global_names([node]):
  306. if not name.is_definition() and name.value in return_variables:
  307. return_variables.remove(name.value)
  308. yield name.value
  309. def _is_node_ending_return_stmt(node):
  310. t = node.type
  311. if t == 'simple_stmt':
  312. return _is_node_ending_return_stmt(node.children[0])
  313. return t == 'return_stmt'