# steinertree.py
  1. from itertools import chain
  2. import networkx as nx
  3. from networkx.utils import not_implemented_for, pairwise
  4. __all__ = ["metric_closure", "steiner_tree"]
  5. @not_implemented_for("directed")
  6. @nx._dispatchable(edge_attrs="weight", returns_graph=True)
  7. def metric_closure(G, weight="weight"):
  8. """Return the metric closure of a graph.
  9. The metric closure of a graph *G* is the complete graph in which each edge
  10. is weighted by the shortest path distance between the nodes in *G* .
  11. Parameters
  12. ----------
  13. G : NetworkX graph
  14. Returns
  15. -------
  16. NetworkX graph
  17. Metric closure of the graph `G`.
  18. Notes
  19. -----
  20. .. deprecated:: 3.6
  21. `metric_closure` is deprecated and will be removed in NetworkX 3.8.
  22. Use :func:`networkx.all_pairs_shortest_path_length` instead.
  23. """
  24. import warnings
  25. warnings.warn(
  26. "metric_closure is deprecated and will be removed in NetworkX 3.8.\n"
  27. "Use nx.all_pairs_shortest_path_length instead.",
  28. category=DeprecationWarning,
  29. stacklevel=5,
  30. )
  31. M = nx.Graph()
  32. Gnodes = set(G)
  33. # check for connected graph while processing first node
  34. all_paths_iter = nx.all_pairs_dijkstra(G, weight=weight)
  35. u, (distance, path) = next(all_paths_iter)
  36. if len(G) != len(distance):
  37. msg = "G is not a connected graph. metric_closure is not defined."
  38. raise nx.NetworkXError(msg)
  39. Gnodes.remove(u)
  40. for v in Gnodes:
  41. M.add_edge(u, v, distance=distance[v], path=path[v])
  42. # first node done -- now process the rest
  43. for u, (distance, path) in all_paths_iter:
  44. Gnodes.remove(u)
  45. for v in Gnodes:
  46. M.add_edge(u, v, distance=distance[v], path=path[v])
  47. return M
  48. def _mehlhorn_steiner_tree(G, terminal_nodes, weight):
  49. distances, paths = nx.multi_source_dijkstra(G, terminal_nodes, weight=weight)
  50. d_1 = {}
  51. s = {}
  52. for v in G.nodes():
  53. s[v] = paths[v][0]
  54. d_1[(v, s[v])] = distances[v]
  55. # G1-G4 names match those from the Mehlhorn 1988 paper.
  56. G_1_prime = nx.Graph()
  57. # iterate over all edges to complete d1
  58. for u, v, data in G.edges(data=True):
  59. su, sv = s[u], s[v]
  60. weight_here = d_1[(u, su)] + data.get(weight, 1) + d_1[(v, sv)]
  61. if not G_1_prime.has_edge(su, sv):
  62. G_1_prime.add_edge(su, sv, weight_d1=weight_here)
  63. else:
  64. new_weight = min(weight_here, G_1_prime[su][sv]["weight_d1"])
  65. G_1_prime.add_edge(su, sv, weight_d1=new_weight)
  66. G_2 = nx.minimum_spanning_edges(G_1_prime, data=True, weight="weight_d1")
  67. G_3 = nx.Graph()
  68. for u, v, _ in G_2:
  69. path = nx.shortest_path(G, u, v, weight=weight)
  70. for n1, n2 in pairwise(path):
  71. G_3.add_edge(n1, n2, weight=G[n1][n2].get(weight, 1))
  72. G_3_mst = list(nx.minimum_spanning_edges(G_3, data=False, weight=weight))
  73. if G.is_multigraph():
  74. G_3_mst = (
  75. (u, v, min(G[u][v], key=lambda k: G[u][v][k].get(weight, 1)))
  76. for u, v in G_3_mst
  77. )
  78. G_4 = G.edge_subgraph(G_3_mst).copy()
  79. _remove_nonterminal_leaves(G_4, terminal_nodes)
  80. return G_4.edges()
  81. def _kou_steiner_tree(G, terminal_nodes, weight):
  82. # Compute the metric closure only for terminal nodes
  83. # Create a complete graph H from the metric edges
  84. H = nx.Graph()
  85. unvisited_terminals = set(terminal_nodes)
  86. # check for connected graph while processing first node
  87. u = unvisited_terminals.pop()
  88. distances, paths = nx.single_source_dijkstra(G, source=u, weight=weight)
  89. if len(G) != len(distances):
  90. msg = "G is not a connected graph."
  91. raise nx.NetworkXError(msg)
  92. for v in unvisited_terminals:
  93. H.add_edge(u, v, distance=distances[v], path=paths[v])
  94. # first node done -- now process the rest
  95. for u in unvisited_terminals.copy():
  96. distances, paths = nx.single_source_dijkstra(G, source=u, weight=weight)
  97. unvisited_terminals.remove(u)
  98. for v in unvisited_terminals:
  99. H.add_edge(u, v, distance=distances[v], path=paths[v])
  100. # Use the 'distance' attribute of each edge provided by H.
  101. mst_edges = nx.minimum_spanning_edges(H, weight="distance", data=True)
  102. # Create an iterator over each edge in each shortest path; repeats are okay
  103. mst_all_edges = chain.from_iterable(pairwise(d["path"]) for u, v, d in mst_edges)
  104. if G.is_multigraph():
  105. mst_all_edges = (
  106. (u, v, min(G[u][v], key=lambda k: G[u][v][k].get(weight, 1)))
  107. for u, v in mst_all_edges
  108. )
  109. # Find the MST again, over this new set of edges
  110. G_S = G.edge_subgraph(mst_all_edges)
  111. T_S = nx.minimum_spanning_edges(G_S, weight="weight", data=False)
  112. # Leaf nodes that are not terminal might still remain; remove them here
  113. T_H = G.edge_subgraph(T_S).copy()
  114. _remove_nonterminal_leaves(T_H, terminal_nodes)
  115. return T_H.edges()
  116. def _remove_nonterminal_leaves(G, terminals):
  117. terminal_set = set(terminals)
  118. leaves = {n for n in G if len(set(G[n]) - {n}) == 1}
  119. nonterminal_leaves = leaves - terminal_set
  120. while nonterminal_leaves:
  121. # Removing a node may create new non-terminal leaves, so we limit
  122. # search for candidate non-terminal nodes to neighbors of current
  123. # non-terminal nodes
  124. candidate_leaves = set.union(*(set(G[n]) for n in nonterminal_leaves))
  125. candidate_leaves -= nonterminal_leaves | terminal_set
  126. # Remove current set of non-terminal nodes
  127. G.remove_nodes_from(nonterminal_leaves)
  128. # Find any new non-terminal nodes from the set of candidates
  129. leaves = {n for n in candidate_leaves if len(set(G[n]) - {n}) == 1}
  130. nonterminal_leaves = leaves - terminal_set
  131. ALGORITHMS = {
  132. "kou": _kou_steiner_tree,
  133. "mehlhorn": _mehlhorn_steiner_tree,
  134. }
  135. @not_implemented_for("directed")
  136. @nx._dispatchable(preserve_all_attrs=True, returns_graph=True)
  137. def steiner_tree(G, terminal_nodes, weight="weight", method=None):
  138. r"""Return an approximation to the minimum Steiner tree of a graph.
  139. The minimum Steiner tree of `G` w.r.t a set of `terminal_nodes` (also *S*)
  140. is a tree within `G` that spans those nodes and has minimum size (sum of
  141. edge weights) among all such trees.
  142. The approximation algorithm is specified with the `method` keyword
  143. argument. All three available algorithms produce a tree whose weight is
  144. within a ``(2 - (2 / l))`` factor of the weight of the optimal Steiner tree,
  145. where ``l`` is the minimum number of leaf nodes across all possible Steiner
  146. trees.
  147. * ``"kou"`` [2]_ (runtime $O(|S| |V|^2)$) computes the minimum spanning tree of
  148. the subgraph of the metric closure of *G* induced by the terminal nodes,
  149. where the metric closure of *G* is the complete graph in which each edge is
  150. weighted by the shortest path distance between the nodes in *G*.
  151. * ``"mehlhorn"`` [3]_ (runtime $O(|E|+|V|\log|V|)$) modifies Kou et al.'s
  152. algorithm, beginning by finding the closest terminal node for each
  153. non-terminal. This data is used to create a complete graph containing only
  154. the terminal nodes, in which edge is weighted with the shortest path
  155. distance between them. The algorithm then proceeds in the same way as Kou
  156. et al..
  157. Parameters
  158. ----------
  159. G : NetworkX graph
  160. terminal_nodes : list
  161. A list of terminal nodes for which minimum steiner tree is
  162. to be found.
  163. weight : string (default = 'weight')
  164. Use the edge attribute specified by this string as the edge weight.
  165. Any edge attribute not present defaults to 1.
  166. method : string, optional (default = 'mehlhorn')
  167. The algorithm to use to approximate the Steiner tree.
  168. Supported options: 'kou', 'mehlhorn'.
  169. Other inputs produce a ValueError.
  170. Returns
  171. -------
  172. NetworkX graph
  173. Approximation to the minimum steiner tree of `G` induced by
  174. `terminal_nodes` .
  175. Raises
  176. ------
  177. NetworkXNotImplemented
  178. If `G` is directed.
  179. ValueError
  180. If the specified `method` is not supported.
  181. Notes
  182. -----
  183. For multigraphs, the edge between two nodes with minimum weight is the
  184. edge put into the Steiner tree.
  185. References
  186. ----------
  187. .. [1] Steiner_tree_problem on Wikipedia.
  188. https://en.wikipedia.org/wiki/Steiner_tree_problem
  189. .. [2] Kou, L., G. Markowsky, and L. Berman. 1981.
  190. ‘A Fast Algorithm for Steiner Trees’.
  191. Acta Informatica 15 (2): 141–45.
  192. https://doi.org/10.1007/BF00288961.
  193. .. [3] Mehlhorn, Kurt. 1988.
  194. ‘A Faster Approximation Algorithm for the Steiner Problem in Graphs’.
  195. Information Processing Letters 27 (3): 125–28.
  196. https://doi.org/10.1016/0020-0190(88)90066-X.
  197. """
  198. if method is None:
  199. method = "mehlhorn"
  200. try:
  201. algo = ALGORITHMS[method]
  202. except KeyError as e:
  203. raise ValueError(f"{method} is not a valid choice for an algorithm.") from e
  204. edges = algo(G, terminal_nodes, weight)
  205. # For multigraph we should add the minimal weight edge keys
  206. if G.is_multigraph():
  207. edges = (
  208. (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight])) for u, v in edges
  209. )
  210. T = G.edge_subgraph(edges)
  211. return T