graph_deduplication.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610
  1. """
  2. This module implements graph deduplication functionality for TorchDynamo's optimization pipeline.
  3. Graph deduplication identifies identical subgraphs in the computational graph and merges them
  4. to reduce redundancy and improve performance. The process involves analyzing regions of the graph,
  5. identifying structurally equivalent regions, and replacing them with a single shared implementation.
  6. This optimization is particularly effective for models with repeated patterns or similar computational
  7. structures across different parts of the network.
  8. """
  9. import logging
  10. import operator
  11. from collections import defaultdict, deque
  12. from collections.abc import Generator, Iterable
  13. from typing import Optional
  14. import torch
  15. import torch.fx
  16. from torch._dynamo import config
  17. from torch.multiprocessing.reductions import StorageWeakRef
  18. from torch.utils._ordered_set import OrderedSet
  19. from .graph_region_tracker import Node, Region
  20. from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
# A usage is addressed by a pair of indices: the first selects a node inside
# a region, the second selects an entry in that node's flattened arguments.
UsageIndex = tuple[int, int]

log = logging.getLogger(__name__)

# Dependency map used by the most recent apply_graph_deduplication() call.
# It is assigned there solely so that tests can inspect it.
last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None
  28. def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
  29. """
  30. This is the main entry point for applying the graph deduplication pass. \
  31. Deduplication occurs in two phases:
  32. 1. Subgraph creation:
  33. Subgraph creation works by taking one representative region from each region \
  34. group and creating a subgraph from it, which will then be used to replace all regions \
  35. in the group. This is implemented by first copying all nodes of the region to the new \
  36. subgraph and then finding all inputs which are not within the region and creating placeholders \
  37. for them. For the outputs, all regions in a region group need to be scanned to ensure the \
  38. largest set of outputs is found, and then an output node is created which returns \
  39. a tuple of all outputs.
  40. 2. Graph replacement:
  41. To replace each region with the extracted subgraph, the node index in the region \
  42. and argument index within the node's flattened args and kwargs are recorded once during \
  43. subgraph creation. This allows us to determine which (external to the region) nodes and \
  44. in which order these nodes are passed as inputs. For the outputs, getitem nodes are created \
  45. for each output, and all nodes in the region with external outputs are replaced by the proper \
  46. getitem node. Finally, all original nodes are erased (there should be no uses of these \
  47. left in the graph).
  48. The deduplication mutates the output_graph argument in place.
  49. Returns a mapping of nodes to their subgraph output replacement node to remap outputs
  50. when they are created in output_graph.
  51. """
  52. duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
  53. output_graph.graph
  54. )
  55. node_to_mutated_arg_positions = (
  56. output_graph.region_tracker.node_to_mutated_arg_positions
  57. )
  58. node_to_additional_deps = _populate_additional_deps(
  59. output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
  60. )
  61. sub_gms: dict[str, torch.fx.GraphModule] = {}
  62. for region_group in duplicated_region_groups:
  63. inds_with_external_users = _get_all_output_indices(region_group)
  64. region = region_group[0]
  65. (
  66. subgraph,
  67. external_node_usages,
  68. node_usage_to_tuple_elems,
  69. ind_to_tuple_spec,
  70. ) = _create_subgraph(region, inds_with_external_users)
  71. # Ignore regions with no args for now, could they possibly be evaluated at compile time?
  72. if not list(external_node_usages):
  73. continue
  74. sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
  75. subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
  76. sub_gms[subgraph_name] = sub_gm
  77. with output_graph.graph.inserting_before():
  78. get_subgraph_node = output_graph.graph.create_node(
  79. "get_attr", subgraph_name, (), {}
  80. )
  81. for region in region_group:
  82. _replace_region_with_subgraph(
  83. output_graph.graph,
  84. region,
  85. get_subgraph_node,
  86. external_node_usages,
  87. node_usage_to_tuple_elems,
  88. ind_to_tuple_spec,
  89. inds_with_external_users,
  90. subgraph_name,
  91. node_to_additional_deps,
  92. node_to_mutated_arg_positions,
  93. )
  94. # This is to expose the updated node_to_additional_deps to tests
  95. global last_node_to_additional_deps
  96. last_node_to_additional_deps = node_to_additional_deps
  97. _stable_topological_sort(
  98. output_graph.graph,
  99. node_to_additional_deps,
  100. )
  101. return sub_gms
def _replace_region_with_subgraph(
    graph: torch.fx.Graph,
    region: Region,
    get_subgraph_node: Node,
    external_node_usages: Iterable[OrderedSet[UsageIndex]],
    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]],
    ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]],
    inds_with_external_users: list[int],
    subgraph_name: str,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
    """
    Replace ``region`` in ``graph`` with a single ``invoke_subgraph`` call.

    The replacement is abandoned, before anything is added to or removed from
    ``graph``, when an external input is used at a mutated argument position
    or when ``_has_aliasing`` reports aliasing for the region.

    Args:
        graph: the graph that contains ``region``; mutated in place.
        region: the nodes to replace.
        get_subgraph_node: node passed as the first argument of the
            ``invoke_subgraph`` call.
        external_node_usages: for every subgraph input, the set of
            (region node index, flattened arg index) positions where that
            input is used inside the region.
        node_usage_to_tuple_elems: for inputs that were flattened, maps the
            first usage of the input to the indices of the region nodes that
            are passed to the subgraph in its place.
        ind_to_tuple_spec: for tuple-valued outputs, maps the region index of
            the output node to its getitem-path -> output-offset table.
        inds_with_external_users: region indices of the nodes whose values
            are used outside the region, in subgraph output order.
        subgraph_name: passed as the second argument of the
            ``invoke_subgraph`` call.
        node_to_additional_deps: extra ordering dependencies; updated in
            place so that erased nodes are dropped and dependencies on them
            are redirected to the new ``invoke_subgraph`` node.
        node_to_mutated_arg_positions: flattened argument positions that each
            node mutates.
    """
    sub_args = []
    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
    for usages in external_node_usages:
        # The first recorded usage is the one used to look up the actual
        # argument node for this region.
        usage = next(iter(usages))
        node_ind, usage_ind = usage
        node = region[node_ind]
        flattened_args_kwargs = _get_flat_args(node, {})
        # Every usage of the input is checked: if any user mutates it, give up.
        for user_ind, node_usage_ind in usages:
            user = region[user_ind]
            if user in node_to_mutated_arg_positions:
                if node_usage_ind in node_to_mutated_arg_positions[user]:
                    log.debug(
                        "NYI: Failed to substitute region %s due to mutation", region
                    )
                    return
        if usage in node_usage_to_tuple_elems:
            # Flattened input: pass the region's own nodes at the recorded
            # indices instead of the external node.
            tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
            flattened_getitem_nodes.update(tuple_elems)
            sub_args.extend(tuple_elems)
        else:
            sub_args.append(flattened_args_kwargs[usage_ind])

    # Input/Output aliasing not supported in HOPs today
    # Note: we should use the nodes in the original graph (the region here)
    # because we use the original traced example values for this check
    if _has_aliasing(
        region, sub_args, inds_with_external_users, flattened_getitem_nodes
    ):
        return

    invoke_args = (get_subgraph_node, subgraph_name, *sub_args)

    invoke_subgraph_node = graph.create_node(
        "call_function",
        torch.ops.higher_order.invoke_subgraph,
        invoke_args,  # type: ignore[arg-type]
        {},
    )

    # `ind` is the running position in the invoke_subgraph result; a tuple
    # output advances it by the number of entries in its tuple spec.
    ind = 0
    flattened_output_nodes: OrderedSet[Node] = OrderedSet()
    for external_user_ind in inds_with_external_users:
        node = region[external_user_ind]
        if _is_tuple_node(node):
            tuple_spec = ind_to_tuple_spec[external_user_ind]
            # _replace_tuple_outputs erases the nodes it returns.
            flattened_output_nodes.update(
                _replace_tuple_outputs(
                    node, ind, tuple_spec, invoke_subgraph_node, graph
                )
            )
            ind += len(tuple_spec)
        else:
            subgraph_output = graph.create_node(
                "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
            )
            node.replace_all_uses_with(subgraph_output, propagate_meta=True)
            ind += 1

    # Erase in reverse topological order
    for node in reversed(region):
        if node in flattened_getitem_nodes:
            # Don't erase these, since they will still be used
            continue

        # Nodes in flattened_output_nodes were already erased above.
        if node not in flattened_output_nodes:
            graph.erase_node(node)

        # Remove any nodes with additional deps
        # This is safe; we've guaranteed that there is
        # no input mutation, so all additional deps
        # will be internal to the subgraph
        node_to_additional_deps.pop(node, None)
        for deps in node_to_additional_deps.values():
            # remove() raises KeyError when `node` is not in `deps`; in that
            # case the set is left untouched.
            try:
                deps.remove(node)
                deps.add(invoke_subgraph_node)
            except KeyError:
                pass

    # Optional self-check, enabled through config.graph_deduplication_lint.
    if config.graph_deduplication_lint:
        print(_detect_cycles(graph, node_to_additional_deps))
        _stable_topological_sort(graph, node_to_additional_deps)
        graph.lint()
  189. def _get_external_inputs(
  190. region: Region,
  191. ) -> dict[Node, OrderedSet[UsageIndex]]:
  192. external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet)
  193. region_unique = set(region)
  194. for node_ind, node in enumerate(region):
  195. flattened_args_kwargs = _get_flat_args(node, {})
  196. for arg_ind, in_node in enumerate(flattened_args_kwargs):
  197. if isinstance(in_node, Node) and in_node not in region_unique:
  198. # in_node may occur in multiple nodes' flat_args
  199. # track this so we can check if the arg is mutated
  200. # Previously, we only needed to track one occurrence
  201. # to be able to map that node to a placeholder
  202. external_node_to_usages[in_node].add((node_ind, arg_ind))
  203. return external_node_to_usages
  204. def _get_all_output_indices(regions: list[Region]) -> list[int]:
  205. # Scan all regions to get the set of all possible output nodes indices in the region
  206. # perhaps we can record this information during region creation for more efficiency?
  207. inds_with_external_users: set[int] = set()
  208. for region in regions:
  209. _get_inds_with_external_users(region, inds_with_external_users)
  210. return sorted(inds_with_external_users)
  211. def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None:
  212. for ind, node in enumerate(region):
  213. for user in node.users:
  214. if user not in region:
  215. if ind not in inds_unique:
  216. inds_unique.add(ind)
def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[
    torch.fx.Graph,
    list[OrderedSet[UsageIndex]],
    dict[UsageIndex, OrderedSet[int]],
    dict[int, dict[tuple[int, ...], int]],
]:
    """
    Build a standalone ``torch.fx.Graph`` equivalent to ``region``.

    Args:
        region: the representative region to copy.
        inds_with_external_users: region indices whose nodes must become
            outputs of the subgraph.

    Returns:
        A 4-tuple of:
        - the new graph: placeholders, copies of the region nodes and an
          output node returning a tuple;
        - one set of usage positions per external input of the region, in
          the order the inputs were returned by ``_get_external_inputs``;
        - for tuple-valued external inputs, a map from the first usage of
          the input to the indices of the region nodes that were turned into
          placeholders;
        - for tuple-valued outputs, a map from the region index of the
          output node to the spec produced by ``_create_getitem_nodes``.
    """
    subgraph: torch.fx.Graph = torch.fx.Graph()
    external_input_to_usages = _get_external_inputs(region)
    external_node_usages = list[OrderedSet[UsageIndex]]()
    # Maps original nodes (external inputs and region nodes) to their
    # counterpart in `subgraph`.
    region_to_subgraph_node = {}
    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]] = {}

    for node, usage_indices in external_input_to_usages.items():
        # Tuple-valued inputs are not passed to the subgraph as tuples;
        # they are flattened as described below.
        if _is_tuple_node(node):
            # If a node is a tuple we will possibly create multiple placeholders for them
            # and track which nodes we won't copy into the subgraph because they are flattened away
            # Later, when replacing each region with this subgraph, we will create a getitem node
            # externally which will perform the flattening on the outer nodes.
            flattened_node_indices = _get_flattened_node_indices(node, region)
            for ind in flattened_node_indices:
                # NOTE(review): "supgraph" looks like a typo of "subgraph";
                # it is left as is because the name is emitted into the
                # generated graph.
                placeholder = subgraph.placeholder(
                    f"supgraph_input_{node.name}_flattened_{ind}"
                )
                region_to_subgraph_node[region[ind]] = placeholder
                flattened_getitem_nodes.add(region[ind])
            node_usage_to_tuple_elems[next(iter(usage_indices))] = (
                flattened_node_indices
            )
        else:
            placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
            region_to_subgraph_node[node] = placeholder

        external_node_usages.append(usage_indices)

    def map_arg(node: Node) -> Node:
        # Translate an original node to its subgraph counterpart when one
        # exists; otherwise return the node unchanged.
        if node in region_to_subgraph_node:
            return region_to_subgraph_node[node]
        else:
            return node

    def copy_to_subgraph(node: Node) -> Node:
        # Copy `node` into the subgraph, remapping its arguments, and
        # remember the copy so later nodes can refer to it.
        subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
        region_to_subgraph_node[node] = subgraph_node
        return subgraph_node

    output_list = []
    ind_to_tuple_spec = {}
    for ind, node in enumerate(region):
        # Nodes replaced by placeholders above are not copied.
        if node not in flattened_getitem_nodes:
            subgraph_node = copy_to_subgraph(node)
            if ind in inds_with_external_users:
                # flatten tuple outputs by generating a getitem node tree
                if _is_tuple_node(node):
                    getitem_nodes, ind_to_tuple_spec[ind] = _create_getitem_nodes(
                        node, subgraph_node, subgraph
                    )
                    output_list.extend(getitem_nodes)
                else:
                    output_list.append(subgraph_node)

    subgraph.output(tuple(output_list))

    return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
def _stable_topological_sort_impl(
    graph: torch.fx.Graph,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
    do_sort: bool = True,
) -> bool:
    """
    Order the nodes of ``graph`` so that every node comes after its inputs
    and after its additional dependencies.

    Args:
        graph: the graph to process; its nodes are reordered in place when
            ``do_sort`` is true.
        node_to_additional_deps: extra dependencies, passed to
            ``_get_flat_args_unique`` together with each node.
        do_sort: when false, no node is moved and only the return value is
            computed.

    Returns:
        True when no node is left waiting on a dependency and the number of
        processed nodes equals the number of nodes in the graph.
    """
    # Nodes are in exactly one of these four collections:

    # - Nodes in `pending` are waiting to be processed (in reverse order):
    # (reversed so that pop() from the end yields nodes in graph order)
    pending = list(reversed(graph.nodes))

    # - Nodes in `ready` have been processed and are already in the correct
    #   order.
    ready = OrderedSet[Node]()

    # - `waiting` is a mapping from a dependency to nodes which depend on that
    #   dependency.
    waiting = defaultdict(list)

    # - `outputs` are always at the end of the graph
    outputs = OrderedSet[Node]()

    # The cursor indicates the last processed node so we can add new nodes
    # after it.
    cursor = None
    while pending:
        node = pending.pop()

        if node.target == "output":
            outputs.add(node)
            assert not node.users, "output nodes should have no users"
            continue

        waiting_for = [
            x
            for x in _get_flat_args_unique(node, node_to_additional_deps)
            if x not in ready
        ]
        if waiting_for:
            # We have unprocessed input nodes. Might as well wait for the last
            # arg so an already sorted list will only recheck this node once.
            waiting[waiting_for[-1]].append(node)
        else:
            ready.add(node)
            # Only move the node when it does not already follow the cursor.
            if cursor and cursor.next is not node and do_sort:
                cursor.append(node)
            cursor = node
            # Mark the nodes that have been waiting for this node to finish as
            # ready to check again.
            pending.extend(reversed(waiting.pop(node, ())))

    ready.update(outputs)
    return not waiting and len(ready) == len(graph.nodes)
  322. def _stable_topological_sort(
  323. graph: torch.fx.Graph,
  324. node_to_additional_deps: dict[Node, OrderedSet[Node]],
  325. ) -> None:
  326. assert _stable_topological_sort_impl(graph, node_to_additional_deps)
  327. def _has_cycle(
  328. graph: torch.fx.Graph,
  329. node_to_additional_deps: dict[Node, OrderedSet[Node]],
  330. ) -> bool:
  331. return not _stable_topological_sort_impl(
  332. graph, node_to_additional_deps, do_sort=False
  333. )
  334. def _populate_additional_deps(
  335. graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]]
  336. ) -> dict[Node, OrderedSet[Node]]:
  337. node_to_additional_deps: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet)
  338. _add_mutation_dependencies(node_to_mutated_arg_positions, node_to_additional_deps)
  339. _add_global_state_dependencies(graph, node_to_additional_deps)
  340. return node_to_additional_deps
  341. def _add_global_state_dependencies(
  342. graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  343. ) -> None:
  344. import torch.amp
  345. all_nodes = list(graph.nodes)
  346. # These are targets of the nodes which need to stay in the same relative place in the graph
  347. global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
  348. all_nodes_dep_on: list[Node] = []
  349. def prev_cur_nodes(
  350. all_nodes: list[Node],
  351. ) -> Generator[tuple[list[Node], Node], None, None]:
  352. prev_nodes: list[Node] = []
  353. next_nodes = list(reversed(all_nodes))
  354. while next_nodes:
  355. cur_node = next_nodes.pop()
  356. yield prev_nodes, cur_node
  357. prev_nodes.append(cur_node)
  358. for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
  359. args_unique = _get_flat_args_unique(cur_node, {})
  360. new_deps = [n for n in all_nodes_dep_on if n not in args_unique]
  361. if new_deps:
  362. additional_deps = node_to_additional_deps[cur_node]
  363. additional_deps.update(new_deps)
  364. if cur_node.target in global_state_targets:
  365. additional_deps = node_to_additional_deps[cur_node]
  366. additional_deps.update(n for n in prev_nodes if n not in args_unique)
  367. all_nodes_dep_on.append(cur_node)
  368. def _add_mutation_dependencies(
  369. node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
  370. node_to_additional_deps: dict[Node, OrderedSet[Node]],
  371. ) -> None:
  372. for node, indices in node_to_mutated_arg_positions.items():
  373. flat_args_kwargs = _get_flat_args(node, {})
  374. # for all mutated args,
  375. # add dependency on usages which occur after node to ensure
  376. # node will always be ordered before them
  377. # also add node as a dependency on usages which
  378. # occur before node to ensure node is ordered after them
  379. for index in indices:
  380. mutated_arg = flat_args_kwargs[index]
  381. for user in mutated_arg.users:
  382. if user is node:
  383. continue
  384. elif user < node:
  385. node_to_additional_deps[node].add(user)
  386. elif user > node:
  387. node_to_additional_deps[user].add(node)
  388. def _has_aliasing(
  389. region: Region,
  390. inputs: list[Node],
  391. inds_with_external_users: list[int],
  392. flattened_getitem_nodes: OrderedSet[Node],
  393. ) -> bool:
  394. input_storages: dict[StorageWeakRef, Node] = dict()
  395. for node in inputs:
  396. if node in flattened_getitem_nodes:
  397. continue
  398. example_value = node.meta["example_value"]
  399. if isinstance(example_value, torch.Tensor):
  400. storage = StorageWeakRef(example_value._typed_storage())
  401. if storage in input_storages:
  402. # input-input aliasing
  403. log.debug(
  404. "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s, %s",
  405. region,
  406. input_storages[storage],
  407. node,
  408. )
  409. return True
  410. input_storages[storage] = node
  411. output_storages: dict[StorageWeakRef, Node] = dict()
  412. for i in inds_with_external_users:
  413. out_node = region[i]
  414. if out_node in flattened_getitem_nodes:
  415. continue
  416. if out_node:
  417. example_value = out_node.meta["example_value"]
  418. assert not isinstance(example_value, list)
  419. if isinstance(example_value, torch.Tensor):
  420. storage = StorageWeakRef(example_value._typed_storage())
  421. if storage in output_storages:
  422. # output-output aliasing
  423. log.debug(
  424. "NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
  425. region,
  426. output_storages[storage],
  427. out_node,
  428. )
  429. return True
  430. output_storages[storage] = out_node
  431. intersected_storages = input_storages.keys() & output_storages.keys()
  432. if len(intersected_storages) > 0:
  433. # input-output aliasing
  434. aliased = [
  435. (input_storages[s], output_storages[s]) for s in intersected_storages
  436. ]
  437. aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
  438. log.debug(
  439. "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
  440. region,
  441. aliased,
  442. )
  443. return True
  444. return False
  445. def _is_tuple_node(node: Node) -> bool:
  446. return isinstance(node.meta["example_value"], tuple)
  447. def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
  448. for user in node.users:
  449. if user.target is operator.getitem and isinstance(user.args[1], int):
  450. yield user
  451. def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
  452. """Returns an ordered set of indices, each representing a node in the region which will be flattened"""
  453. flattened_node_to_ind = {n: i for i, n in enumerate(region)}
  454. node_indices: OrderedSet[int] = OrderedSet()
  455. queue = deque(_get_children_getitems(node))
  456. while queue:
  457. cur_node = queue.popleft()
  458. if any(user in region for user in cur_node.users):
  459. node_indices.add(flattened_node_to_ind[cur_node])
  460. for child in _get_children_getitems(cur_node):
  461. queue.append(child)
  462. return node_indices
  463. def _create_getitem_nodes(
  464. node: Node, subgraph_tuple_node: Node, subgraph: torch.fx.Graph
  465. ) -> tuple[list[Node], dict[tuple[int, ...], int]]:
  466. tup = node.meta["example_value"]
  467. assert isinstance(tup, tuple), "_get_getitem_children expects tuple"
  468. getitem_nodes: list[Node] = []
  469. queue = deque([(e, (i,), subgraph_tuple_node) for i, e in enumerate(tup)])
  470. path_to_output_index = {}
  471. while queue:
  472. cur_elem, path, parent = queue.popleft()
  473. with subgraph.inserting_after(parent):
  474. new_getitem_node = subgraph.create_node(
  475. "call_function", operator.getitem, (parent, path[-1]), {}
  476. )
  477. new_getitem_node.meta["example_value"] = cur_elem
  478. path_to_output_index[path] = len(getitem_nodes)
  479. getitem_nodes.append(new_getitem_node)
  480. if isinstance(cur_elem, tuple):
  481. queue.extend(
  482. [(e, path + (i,), new_getitem_node) for i, e in enumerate(cur_elem)] # type: ignore[arg-type,misc]
  483. )
  484. return getitem_nodes, path_to_output_index # type: ignore[return-value]
def _replace_tuple_outputs(
    node: Node,
    output_index: int,
    tuple_spec: dict[tuple[int, ...], int],
    invoke_subgraph_node: Node,
    graph: torch.fx.Graph,
) -> OrderedSet[Node]:
    """
    Replace the getitem tree hanging off the tuple-valued ``node`` with
    getitem nodes that read from ``invoke_subgraph_node``.

    Args:
        node: the tuple-valued region node being replaced.
        output_index: position in the ``invoke_subgraph_node`` result at
            which the flattened outputs of this tuple start.
        tuple_spec: maps a getitem index path to an offset that is added to
            ``output_index``.
        invoke_subgraph_node: the node whose result is indexed.
        graph: the graph containing all of the above; mutated in place.

    Returns:
        The nodes erased from ``graph``: every visited getitem node and
        ``node`` itself.
    """
    assert _is_tuple_node(node), "_replace_tuple_outputs expects a tuple node"

    # Queue entries are (getitem node, index path leading to it from `node`).
    queue = deque((c, (c.args[1],)) for c in _get_children_getitems(node))
    erased_nodes: OrderedSet[Node] = OrderedSet()
    while queue:
        # pop() takes from the right end, so the traversal is last-in,
        # first-out.
        cur_node, path = queue.pop()

        # The children are collected before cur_node's uses are rewritten
        # below.
        for c in _get_children_getitems(cur_node):
            queue.append((c, path + (c.args[1],)))  # type: ignore[return-value, arg-type]

        with graph.inserting_after(invoke_subgraph_node):
            subgraph_output = graph.create_node(
                "call_function",
                operator.getitem,
                (invoke_subgraph_node, output_index + tuple_spec[path]),  # type: ignore[index]
                {},
            )
        cur_node.replace_all_uses_with(subgraph_output, propagate_meta=True)
        graph.erase_node(cur_node)
        erased_nodes.add(cur_node)

    graph.erase_node(node)
    erased_nodes.add(node)

    return erased_nodes