branchings.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042
  1. """
  2. Algorithms for finding optimum branchings and spanning arborescences.
  3. This implementation is based on:
  4. J. Edmonds, Optimum branchings, J. Res. Natl. Bur. Standards 71B (1967),
  5. 233–240. URL: http://archive.org/details/jresv71Bn4p233
  6. """
# TODO: Implement method from Gabow, Galil, Spencer and Tarjan:
  8. #
  9. # @article{
  10. # year={1986},
  11. # issn={0209-9683},
  12. # journal={Combinatorica},
  13. # volume={6},
  14. # number={2},
  15. # doi={10.1007/BF02579168},
  16. # title={Efficient algorithms for finding minimum spanning trees in
  17. # undirected and directed graphs},
  18. # url={https://doi.org/10.1007/BF02579168},
  19. # publisher={Springer-Verlag},
  20. # keywords={68 B 15; 68 C 05},
  21. # author={Gabow, Harold N. and Galil, Zvi and Spencer, Thomas and Tarjan,
  22. # Robert E.},
  23. # pages={109-122},
  24. # language={English}
  25. # }
  26. import string
  27. from dataclasses import dataclass, field
  28. from operator import itemgetter
  29. from queue import PriorityQueue
  30. import networkx as nx
  31. from networkx.utils import py_random_state
  32. from .recognition import is_arborescence, is_branching
  33. __all__ = [
  34. "branching_weight",
  35. "greedy_branching",
  36. "maximum_branching",
  37. "minimum_branching",
  38. "minimal_branching",
  39. "maximum_spanning_arborescence",
  40. "minimum_spanning_arborescence",
  41. "ArborescenceIterator",
  42. ]
  43. KINDS = {"max", "min"}
  44. STYLES = {
  45. "branching": "branching",
  46. "arborescence": "arborescence",
  47. "spanning arborescence": "arborescence",
  48. }
  49. INF = float("inf")
  50. @py_random_state(1)
  51. def random_string(L=15, seed=None):
  52. return "".join([seed.choice(string.ascii_letters) for n in range(L)])
  53. def _min_weight(weight):
  54. return -weight
  55. def _max_weight(weight):
  56. return weight
  57. @nx._dispatchable(edge_attrs={"attr": "default"})
  58. def branching_weight(G, attr="weight", default=1):
  59. """
  60. Returns the total weight of a branching.
  61. You must access this function through the networkx.algorithms.tree module.
  62. Parameters
  63. ----------
  64. G : DiGraph
  65. The directed graph.
  66. attr : str
  67. The attribute to use as weights. If None, then each edge will be
  68. treated equally with a weight of 1.
  69. default : float
  70. When `attr` is not None, then if an edge does not have that attribute,
  71. `default` specifies what value it should take.
  72. Returns
  73. -------
  74. weight: int or float
  75. The total weight of the branching.
  76. Examples
  77. --------
  78. >>> G = nx.DiGraph()
  79. >>> G.add_weighted_edges_from([(0, 1, 2), (1, 2, 4), (2, 3, 3), (3, 4, 2)])
  80. >>> nx.tree.branching_weight(G)
  81. 11
  82. """
  83. return sum(edge[2].get(attr, default) for edge in G.edges(data=True))
  84. @py_random_state(4)
  85. @nx._dispatchable(edge_attrs={"attr": "default"}, returns_graph=True)
  86. def greedy_branching(G, attr="weight", default=1, kind="max", seed=None):
  87. """
  88. Returns a branching obtained through a greedy algorithm.
  89. This algorithm is wrong, and cannot give a proper optimal branching.
  90. However, we include it for pedagogical reasons, as it can be helpful to
  91. see what its outputs are.
  92. The output is a branching, and possibly, a spanning arborescence. However,
  93. it is not guaranteed to be optimal in either case.
  94. Parameters
  95. ----------
  96. G : DiGraph
  97. The directed graph to scan.
  98. attr : str
  99. The attribute to use as weights. If None, then each edge will be
  100. treated equally with a weight of 1.
  101. default : float
  102. When `attr` is not None, then if an edge does not have that attribute,
  103. `default` specifies what value it should take.
  104. kind : str
  105. The type of optimum to search for: 'min' or 'max' greedy branching.
  106. seed : integer, random_state, or None (default)
  107. Indicator of random number generation state.
  108. See :ref:`Randomness<randomness>`.
  109. Returns
  110. -------
  111. B : directed graph
  112. The greedily obtained branching.
  113. """
  114. if kind not in KINDS:
  115. raise nx.NetworkXException("Unknown value for `kind`.")
  116. if kind == "min":
  117. reverse = False
  118. else:
  119. reverse = True
  120. if attr is None:
  121. # Generate a random string the graph probably won't have.
  122. attr = random_string(seed=seed)
  123. edges = [(u, v, data.get(attr, default)) for (u, v, data) in G.edges(data=True)]
  124. # We sort by weight, but also by nodes to normalize behavior across runs.
  125. try:
  126. edges.sort(key=itemgetter(2, 0, 1), reverse=reverse)
  127. except TypeError:
  128. # This will fail in Python 3.x if the nodes are of varying types.
  129. # In that case, we use the arbitrary order.
  130. edges.sort(key=itemgetter(2), reverse=reverse)
  131. # The branching begins with a forest of no edges.
  132. B = nx.DiGraph()
  133. B.add_nodes_from(G)
  134. # Now we add edges greedily so long we maintain the branching.
  135. uf = nx.utils.UnionFind()
  136. for i, (u, v, w) in enumerate(edges):
  137. if uf[u] == uf[v]:
  138. # Adding this edge would form a directed cycle.
  139. continue
  140. elif B.in_degree(v) == 1:
  141. # The edge would increase the degree to be greater than one.
  142. continue
  143. else:
  144. # If attr was None, then don't insert weights...
  145. data = {}
  146. if attr is not None:
  147. data[attr] = w
  148. B.add_edge(u, v, **data)
  149. uf.union(u, v)
  150. return B
  151. @nx._dispatchable(preserve_edge_attrs=True, returns_graph=True)
  152. def maximum_branching(
  153. G,
  154. attr="weight",
  155. default=1,
  156. preserve_attrs=False,
  157. partition=None,
  158. ):
  159. #######################################
  160. ### Data Structure Helper Functions ###
  161. #######################################
  162. def edmonds_add_edge(G, edge_index, u, v, key, **d):
  163. """
  164. Adds an edge to `G` while also updating the edge index.
  165. This algorithm requires the use of an external dictionary to track
  166. the edge keys since it is possible that the source or destination
  167. node of an edge will be changed and the default key-handling
  168. capabilities of the MultiDiGraph class do not account for this.
  169. Parameters
  170. ----------
  171. G : MultiDiGraph
  172. The graph to insert an edge into.
  173. edge_index : dict
  174. A mapping from integers to the edges of the graph.
  175. u : node
  176. The source node of the new edge.
  177. v : node
  178. The destination node of the new edge.
  179. key : int
  180. The key to use from `edge_index`.
  181. d : keyword arguments, optional
  182. Other attributes to store on the new edge.
  183. """
  184. if key in edge_index:
  185. uu, vv, _ = edge_index[key]
  186. if (u != uu) or (v != vv):
  187. raise Exception(f"Key {key!r} is already in use.")
  188. G.add_edge(u, v, key, **d)
  189. edge_index[key] = (u, v, G.succ[u][v][key])
  190. def edmonds_remove_node(G, edge_index, n):
  191. """
  192. Remove a node from the graph, updating the edge index to match.
  193. Parameters
  194. ----------
  195. G : MultiDiGraph
  196. The graph to remove an edge from.
  197. edge_index : dict
  198. A mapping from integers to the edges of the graph.
  199. n : node
  200. The node to remove from `G`.
  201. """
  202. keys = set()
  203. for keydict in G.pred[n].values():
  204. keys.update(keydict)
  205. for keydict in G.succ[n].values():
  206. keys.update(keydict)
  207. for key in keys:
  208. del edge_index[key]
  209. G.remove_node(n)
  210. #######################
  211. ### Algorithm Setup ###
  212. #######################
  213. # Pick an attribute name that the original graph is unlikly to have
  214. candidate_attr = "edmonds' secret candidate attribute"
  215. new_node_base_name = "edmonds new node base name "
  216. G_original = G
  217. G = nx.MultiDiGraph()
  218. G.__networkx_cache__ = None # Disable caching
  219. # A dict to reliably track mutations to the edges using the key of the edge.
  220. G_edge_index = {}
  221. # Each edge is given an arbitrary numerical key
  222. for key, (u, v, data) in enumerate(G_original.edges(data=True)):
  223. d = {attr: data.get(attr, default)}
  224. if data.get(partition) is not None:
  225. d[partition] = data.get(partition)
  226. if preserve_attrs:
  227. for d_k, d_v in data.items():
  228. if d_k != attr:
  229. d[d_k] = d_v
  230. edmonds_add_edge(G, G_edge_index, u, v, key, **d)
  231. level = 0 # Stores the number of contracted nodes
  232. # These are the buckets from the paper.
  233. #
  234. # In the paper, G^i are modified versions of the original graph.
  235. # D^i and E^i are the nodes and edges of the maximal edges that are
  236. # consistent with G^i. In this implementation, D^i and E^i are stored
  237. # together as the graph B^i. We will have strictly more B^i then the
  238. # paper will have.
  239. #
  240. # Note that the data in graphs and branchings are tuples with the graph as
  241. # the first element and the edge index as the second.
  242. B = nx.MultiDiGraph()
  243. B_edge_index = {}
  244. graphs = [] # G^i list
  245. branchings = [] # B^i list
  246. selected_nodes = set() # D^i bucket
  247. uf = nx.utils.UnionFind()
  248. # A list of lists of edge indices. Each list is a circuit for graph G^i.
  249. # Note the edge list is not required to be a circuit in G^0.
  250. circuits = []
  251. # Stores the index of the minimum edge in the circuit found in G^i and B^i.
  252. # The ordering of the edges seems to preserver the weight ordering from
  253. # G^0. So even if the circuit does not form a circuit in G^0, it is still
  254. # true that the minimum edges in circuit G^0 (despite their weights being
  255. # different)
  256. minedge_circuit = []
  257. ###########################
  258. ### Algorithm Structure ###
  259. ###########################
  260. # Each step listed in the algorithm is an inner function. Thus, the overall
  261. # loop structure is:
  262. #
  263. # while True:
  264. # step_I1()
  265. # if cycle detected:
  266. # step_I2()
  267. # elif every node of G is in D and E is a branching:
  268. # break
  269. ##################################
  270. ### Algorithm Helper Functions ###
  271. ##################################
  272. def edmonds_find_desired_edge(v):
  273. """
  274. Find the edge directed towards v with maximal weight.
  275. If an edge partition exists in this graph, return the included
  276. edge if it exists and never return any excluded edge.
  277. Note: There can only be one included edge for each vertex otherwise
  278. the edge partition is empty.
  279. Parameters
  280. ----------
  281. v : node
  282. The node to search for the maximal weight incoming edge.
  283. """
  284. edge = None
  285. max_weight = -INF
  286. for u, _, key, data in G.in_edges(v, data=True, keys=True):
  287. # Skip excluded edges
  288. if data.get(partition) == nx.EdgePartition.EXCLUDED:
  289. continue
  290. new_weight = data[attr]
  291. # Return the included edge
  292. if data.get(partition) == nx.EdgePartition.INCLUDED:
  293. max_weight = new_weight
  294. edge = (u, v, key, new_weight, data)
  295. break
  296. # Find the best open edge
  297. if new_weight > max_weight:
  298. max_weight = new_weight
  299. edge = (u, v, key, new_weight, data)
  300. return edge, max_weight
  301. def edmonds_step_I2(v, desired_edge, level):
  302. """
  303. Perform step I2 from Edmonds' paper
  304. First, check if the last step I1 created a cycle. If it did not, do nothing.
  305. If it did, store the cycle for later reference and contract it.
  306. Parameters
  307. ----------
  308. v : node
  309. The current node to consider
  310. desired_edge : edge
  311. The minimum desired edge to remove from the cycle.
  312. level : int
  313. The current level, i.e. the number of cycles that have already been removed.
  314. """
  315. u = desired_edge[0]
  316. Q_nodes = nx.shortest_path(B, v, u)
  317. Q_edges = [
  318. list(B[Q_nodes[i]][vv].keys())[0] for i, vv in enumerate(Q_nodes[1:])
  319. ]
  320. Q_edges.append(desired_edge[2]) # Add the new edge key to complete the circuit
  321. # Get the edge in the circuit with the minimum weight.
  322. # Also, save the incoming weights for each node.
  323. minweight = INF
  324. minedge = None
  325. Q_incoming_weight = {}
  326. for edge_key in Q_edges:
  327. u, v, data = B_edge_index[edge_key]
  328. w = data[attr]
  329. # We cannot remove an included edge, even if it is the
  330. # minimum edge in the circuit
  331. Q_incoming_weight[v] = w
  332. if data.get(partition) == nx.EdgePartition.INCLUDED:
  333. continue
  334. if w < minweight:
  335. minweight = w
  336. minedge = edge_key
  337. circuits.append(Q_edges)
  338. minedge_circuit.append(minedge)
  339. graphs.append((G.copy(), G_edge_index.copy()))
  340. branchings.append((B.copy(), B_edge_index.copy()))
  341. # Mutate the graph to contract the circuit
  342. new_node = new_node_base_name + str(level)
  343. G.add_node(new_node)
  344. new_edges = []
  345. for u, v, key, data in G.edges(data=True, keys=True):
  346. if u in Q_incoming_weight:
  347. if v in Q_incoming_weight:
  348. # Circuit edge. For the moment do nothing,
  349. # eventually it will be removed.
  350. continue
  351. else:
  352. # Outgoing edge from a node in the circuit.
  353. # Make it come from the new node instead
  354. dd = data.copy()
  355. new_edges.append((new_node, v, key, dd))
  356. else:
  357. if v in Q_incoming_weight:
  358. # Incoming edge to the circuit.
  359. # Update it's weight
  360. w = data[attr]
  361. w += minweight - Q_incoming_weight[v]
  362. dd = data.copy()
  363. dd[attr] = w
  364. new_edges.append((u, new_node, key, dd))
  365. else:
  366. # Outside edge. No modification needed
  367. continue
  368. for node in Q_nodes:
  369. edmonds_remove_node(G, G_edge_index, node)
  370. edmonds_remove_node(B, B_edge_index, node)
  371. selected_nodes.difference_update(set(Q_nodes))
  372. for u, v, key, data in new_edges:
  373. edmonds_add_edge(G, G_edge_index, u, v, key, **data)
  374. if candidate_attr in data:
  375. del data[candidate_attr]
  376. edmonds_add_edge(B, B_edge_index, u, v, key, **data)
  377. uf.union(u, v)
  378. def is_root(G, u, edgekeys):
  379. """
  380. Returns True if `u` is a root node in G.
  381. Node `u` is a root node if its in-degree over the specified edges is zero.
  382. Parameters
  383. ----------
  384. G : Graph
  385. The current graph.
  386. u : node
  387. The node in `G` to check if it is a root.
  388. edgekeys : iterable of edges
  389. The edges for which to check if `u` is a root of.
  390. """
  391. if u not in G:
  392. raise Exception(f"{u!r} not in G")
  393. for v in G.pred[u]:
  394. for edgekey in G.pred[u][v]:
  395. if edgekey in edgekeys:
  396. return False, edgekey
  397. else:
  398. return True, None
  399. nodes = iter(list(G.nodes))
  400. while True:
  401. try:
  402. v = next(nodes)
  403. except StopIteration:
  404. # If there are no more new nodes to consider, then we should
  405. # meet stopping condition (b) from the paper:
  406. # (b) every node of G^i is in D^i and E^i is a branching
  407. assert len(G) == len(B)
  408. if len(B):
  409. assert is_branching(B)
  410. graphs.append((G.copy(), G_edge_index.copy()))
  411. branchings.append((B.copy(), B_edge_index.copy()))
  412. circuits.append([])
  413. minedge_circuit.append(None)
  414. break
  415. else:
  416. #####################
  417. ### BEGIN STEP I1 ###
  418. #####################
  419. # This is a very simple step, so I don't think it needs a method of it's own
  420. if v in selected_nodes:
  421. continue
  422. selected_nodes.add(v)
  423. B.add_node(v)
  424. desired_edge, desired_edge_weight = edmonds_find_desired_edge(v)
  425. # There might be no desired edge if all edges are excluded or
  426. # v is the last node to be added to B, the ultimate root of the branching
  427. if desired_edge is not None and desired_edge_weight > 0:
  428. u = desired_edge[0]
  429. # Flag adding the edge will create a circuit before merging the two
  430. # connected components of u and v in B
  431. circuit = uf[u] == uf[v]
  432. dd = {attr: desired_edge_weight}
  433. if desired_edge[4].get(partition) is not None:
  434. dd[partition] = desired_edge[4].get(partition)
  435. edmonds_add_edge(B, B_edge_index, u, v, desired_edge[2], **dd)
  436. G[u][v][desired_edge[2]][candidate_attr] = True
  437. uf.union(u, v)
  438. ###################
  439. ### END STEP I1 ###
  440. ###################
  441. #####################
  442. ### BEGIN STEP I2 ###
  443. #####################
  444. if circuit:
  445. edmonds_step_I2(v, desired_edge, level)
  446. nodes = iter(list(G.nodes()))
  447. level += 1
  448. ###################
  449. ### END STEP I2 ###
  450. ###################
  451. #####################
  452. ### BEGIN STEP I3 ###
  453. #####################
  454. # Create a new graph of the same class as the input graph
  455. H = G_original.__class__()
  456. # Start with the branching edges in the last level.
  457. edges = set(branchings[level][1])
  458. while level > 0:
  459. level -= 1
  460. # The current level is i, and we start counting from 0.
  461. #
  462. # We need the node at level i+1 that results from merging a circuit
  463. # at level i. basename_0 is the first merged node and this happens
  464. # at level 1. That is basename_0 is a node at level 1 that results
  465. # from merging a circuit at level 0.
  466. merged_node = new_node_base_name + str(level)
  467. circuit = circuits[level]
  468. isroot, edgekey = is_root(graphs[level + 1][0], merged_node, edges)
  469. edges.update(circuit)
  470. if isroot:
  471. minedge = minedge_circuit[level]
  472. if minedge is None:
  473. raise Exception
  474. # Remove the edge in the cycle with minimum weight
  475. edges.remove(minedge)
  476. else:
  477. # We have identified an edge at the next higher level that
  478. # transitions into the merged node at this level. That edge
  479. # transitions to some corresponding node at the current level.
  480. #
  481. # We want to remove an edge from the cycle that transitions
  482. # into the corresponding node, otherwise the result would not
  483. # be a branching.
  484. G, G_edge_index = graphs[level]
  485. target = G_edge_index[edgekey][1]
  486. for edgekey in circuit:
  487. u, v, data = G_edge_index[edgekey]
  488. if v == target:
  489. break
  490. else:
  491. raise Exception("Couldn't find edge incoming to merged node.")
  492. edges.remove(edgekey)
  493. H.add_nodes_from(G_original)
  494. for edgekey in edges:
  495. u, v, d = graphs[0][1][edgekey]
  496. dd = {attr: d[attr]}
  497. if preserve_attrs:
  498. for key, value in d.items():
  499. if key not in [attr, candidate_attr]:
  500. dd[key] = value
  501. H.add_edge(u, v, **dd)
  502. ###################
  503. ### END STEP I3 ###
  504. ###################
  505. return H
  506. @nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
  507. def minimum_branching(
  508. G, attr="weight", default=1, preserve_attrs=False, partition=None
  509. ):
  510. for _, _, d in G.edges(data=True):
  511. d[attr] = -d.get(attr, default)
  512. nx._clear_cache(G)
  513. B = maximum_branching(G, attr, default, preserve_attrs, partition)
  514. for _, _, d in G.edges(data=True):
  515. d[attr] = -d.get(attr, default)
  516. nx._clear_cache(G)
  517. for _, _, d in B.edges(data=True):
  518. d[attr] = -d.get(attr, default)
  519. nx._clear_cache(B)
  520. return B
  521. @nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
  522. def minimal_branching(
  523. G, /, *, attr="weight", default=1, preserve_attrs=False, partition=None
  524. ):
  525. """
  526. Returns a minimal branching from `G`.
  527. A minimal branching is a branching similar to a minimal arborescence but
  528. without the requirement that the result is actually a spanning arborescence.
  529. This allows minimal branchinges to be computed over graphs which may not
  530. have arborescence (such as multiple components).
  531. Parameters
  532. ----------
  533. G : (multi)digraph-like
  534. The graph to be searched.
  535. attr : str
  536. The edge attribute used in determining optimality.
  537. default : float
  538. The value of the edge attribute used if an edge does not have
  539. the attribute `attr`.
  540. preserve_attrs : bool
  541. If True, preserve the other attributes of the original graph (that are not
  542. passed to `attr`)
  543. partition : str
  544. The key for the edge attribute containing the partition
  545. data on the graph. Edges can be included, excluded or open using the
  546. `EdgePartition` enum.
  547. Returns
  548. -------
  549. B : (multi)digraph-like
  550. A minimal branching.
  551. """
  552. max_weight = -INF
  553. min_weight = INF
  554. for _, _, w in G.edges(data=attr, default=default):
  555. if w > max_weight:
  556. max_weight = w
  557. if w < min_weight:
  558. min_weight = w
  559. for _, _, d in G.edges(data=True):
  560. # Transform the weights so that the minimum weight is larger than
  561. # the difference between the max and min weights. This is important
  562. # in order to prevent the edge weights from becoming negative during
  563. # computation
  564. d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default)
  565. nx._clear_cache(G)
  566. B = maximum_branching(G, attr, default, preserve_attrs, partition)
  567. # Reverse the weight transformations
  568. for _, _, d in G.edges(data=True):
  569. d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default)
  570. nx._clear_cache(G)
  571. for _, _, d in B.edges(data=True):
  572. d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default)
  573. nx._clear_cache(B)
  574. return B
  575. @nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
  576. def maximum_spanning_arborescence(
  577. G, attr="weight", default=1, preserve_attrs=False, partition=None
  578. ):
  579. # In order to use the same algorithm is the maximum branching, we need to adjust
  580. # the weights of the graph. The branching algorithm can choose to not include an
  581. # edge if it doesn't help find a branching, mainly triggered by edges with negative
  582. # weights.
  583. #
  584. # To prevent this from happening while trying to find a spanning arborescence, we
  585. # just have to tweak the edge weights so that they are all positive and cannot
  586. # become negative during the branching algorithm, find the maximum branching and
  587. # then return them to their original values.
  588. min_weight = INF
  589. max_weight = -INF
  590. for _, _, w in G.edges(data=attr, default=default):
  591. if w < min_weight:
  592. min_weight = w
  593. if w > max_weight:
  594. max_weight = w
  595. for _, _, d in G.edges(data=True):
  596. d[attr] = d.get(attr, default) - min_weight + 1 - (min_weight - max_weight)
  597. nx._clear_cache(G)
  598. B = maximum_branching(G, attr, default, preserve_attrs, partition)
  599. for _, _, d in G.edges(data=True):
  600. d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight)
  601. nx._clear_cache(G)
  602. for _, _, d in B.edges(data=True):
  603. d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight)
  604. nx._clear_cache(B)
  605. if not is_arborescence(B):
  606. raise nx.exception.NetworkXException("No maximum spanning arborescence in G.")
  607. return B
  608. @nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
  609. def minimum_spanning_arborescence(
  610. G, attr="weight", default=1, preserve_attrs=False, partition=None
  611. ):
  612. B = minimal_branching(
  613. G,
  614. attr=attr,
  615. default=default,
  616. preserve_attrs=preserve_attrs,
  617. partition=partition,
  618. )
  619. if not is_arborescence(B):
  620. raise nx.exception.NetworkXException("No minimum spanning arborescence in G.")
  621. return B
  622. docstring_branching = """
  623. Returns a {kind} {style} from G.
  624. Parameters
  625. ----------
  626. G : (multi)digraph-like
  627. The graph to be searched.
  628. attr : str
  629. The edge attribute used to in determining optimality.
  630. default : float
  631. The value of the edge attribute used if an edge does not have
  632. the attribute `attr`.
  633. preserve_attrs : bool
  634. If True, preserve the other attributes of the original graph (that are not
  635. passed to `attr`)
  636. partition : str
  637. The key for the edge attribute containing the partition
  638. data on the graph. Edges can be included, excluded or open using the
  639. `EdgePartition` enum.
  640. Returns
  641. -------
  642. B : (multi)digraph-like
  643. A {kind} {style}.
  644. """
  645. docstring_arborescence = (
  646. docstring_branching
  647. + """
  648. Raises
  649. ------
  650. NetworkXException
  651. If the graph does not contain a {kind} {style}.
  652. """
  653. )
  654. maximum_branching.__doc__ = docstring_branching.format(
  655. kind="maximum", style="branching"
  656. )
  657. minimum_branching.__doc__ = (
  658. docstring_branching.format(kind="minimum", style="branching")
  659. + """
  660. See Also
  661. --------
  662. minimal_branching
  663. """
  664. )
  665. maximum_spanning_arborescence.__doc__ = docstring_arborescence.format(
  666. kind="maximum", style="spanning arborescence"
  667. )
  668. minimum_spanning_arborescence.__doc__ = docstring_arborescence.format(
  669. kind="minimum", style="spanning arborescence"
  670. )
  671. class ArborescenceIterator:
  672. """
  673. Iterate over all spanning arborescences of a graph in either increasing or
  674. decreasing cost.
  675. Notes
  676. -----
  677. This iterator uses the partition scheme from [1]_ (included edges,
  678. excluded edges and open edges). It generates minimum spanning
  679. arborescences using a modified Edmonds' Algorithm which respects the
  680. partition of edges. For arborescences with the same weight, ties are
  681. broken arbitrarily.
  682. References
  683. ----------
  684. .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning
  685. trees in order of increasing cost, Pesquisa Operacional, 2005-08,
  686. Vol. 25 (2), p. 219-229,
  687. https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en
  688. """
  689. @dataclass(order=True)
  690. class Partition:
  691. """
  692. This dataclass represents a partition and stores a dict with the edge
  693. data and the weight of the minimum spanning arborescence of the
  694. partition dict.
  695. """
  696. mst_weight: float
  697. partition_dict: dict = field(compare=False)
  698. def __copy__(self):
  699. return ArborescenceIterator.Partition(
  700. self.mst_weight, self.partition_dict.copy()
  701. )
  702. def __init__(self, G, weight="weight", minimum=True, init_partition=None):
  703. """
  704. Initialize the iterator
  705. Parameters
  706. ----------
  707. G : nx.DiGraph
  708. The directed graph which we need to iterate trees over
  709. weight : String, default = "weight"
  710. The edge attribute used to store the weight of the edge
  711. minimum : bool, default = True
  712. Return the trees in increasing order while true and decreasing order
  713. while false.
  714. init_partition : tuple, default = None
  715. In the case that certain edges have to be included or excluded from
  716. the arborescences, `init_partition` should be in the form
  717. `(included_edges, excluded_edges)` where each edges is a
  718. `(u, v)`-tuple inside an iterable such as a list or set.
  719. """
  720. self.G = G.copy()
  721. self.weight = weight
  722. self.minimum = minimum
  723. self.method = (
  724. minimum_spanning_arborescence if minimum else maximum_spanning_arborescence
  725. )
  726. # Randomly create a key for an edge attribute to hold the partition data
  727. self.partition_key = (
  728. "ArborescenceIterators super secret partition attribute name"
  729. )
  730. if init_partition is not None:
  731. partition_dict = {}
  732. for e in init_partition[0]:
  733. partition_dict[e] = nx.EdgePartition.INCLUDED
  734. for e in init_partition[1]:
  735. partition_dict[e] = nx.EdgePartition.EXCLUDED
  736. self.init_partition = ArborescenceIterator.Partition(0, partition_dict)
  737. else:
  738. self.init_partition = None
  def __iter__(self):
      """
      Prepare iteration by seeding the priority queue with the optimum.

      Returns
      -------
      ArborescenceIterator
          This object, with its queue holding the partition of the optimal
          arborescence (the initial partition, or an empty one).
      """
      self.partition_queue = PriorityQueue()
      self._clear_partition(self.G)

      start = self.init_partition
      # Restrict the search space to the user's initial partition, if given.
      if start is not None:
          self._write_partition(start)

      best = self.method(
          self.G,
          self.weight,
          partition=self.partition_key,
          preserve_attrs=True,
      )
      total = best.size(weight=self.weight)
      if not self.minimum:
          # Negated so the min-priority queue hands out heavier trees first.
          total = -total

      seed_dict = {} if start is None else start.partition_dict
      self.partition_queue.put(self.Partition(total, seed_dict))
      return self
  768. def __next__(self):
  769. """
  770. Returns
  771. -------
  772. (multi)Graph
  773. The spanning tree of next greatest weight, which ties broken
  774. arbitrarily.
  775. """
  776. if self.partition_queue.empty():
  777. del self.G, self.partition_queue
  778. raise StopIteration
  779. partition = self.partition_queue.get()
  780. self._write_partition(partition)
  781. next_arborescence = self.method(
  782. self.G,
  783. self.weight,
  784. partition=self.partition_key,
  785. preserve_attrs=True,
  786. )
  787. self._partition(partition, next_arborescence)
  788. self._clear_partition(next_arborescence)
  789. return next_arborescence
  def _partition(self, partition, partition_arborescence):
      """
      Split `partition` into child partitions using its optimal arborescence.

      For each edge of `partition_arborescence` that `partition` leaves open,
      a child partition is formed that excludes that edge while including
      every open edge visited before it. A child is pushed onto the priority
      queue, with its (possibly negated) weight, when ``self.method`` finds an
      arborescence for it; children on which it raises
      ``nx.NetworkXException`` are dropped.

      Parameters
      ----------
      partition : Partition
          The Partition instance used to generate the current optimal
          arborescence.

      partition_arborescence : graph
          The optimal spanning arborescence of `partition`, as returned by
          ``self.method``.
      """
      parent_dict = partition.partition_dict
      # `child` is the partition currently being tested; `prefix` accumulates
      # the open edges that have been forced in so far.
      child = self.Partition(0, parent_dict.copy())
      prefix = self.Partition(0, parent_dict.copy())
      for edge in partition_arborescence.edges:
          if edge in parent_dict:
              # Already fixed by the parent partition; not an open edge.
              continue
          child.partition_dict[edge] = nx.EdgePartition.EXCLUDED
          prefix.partition_dict[edge] = nx.EdgePartition.INCLUDED
          self._write_partition(child)
          try:
              child_tree = self.method(
                  self.G,
                  self.weight,
                  partition=self.partition_key,
                  preserve_attrs=True,
              )
              tree_weight = child_tree.size(weight=self.weight)
              child.mst_weight = tree_weight if self.minimum else -tree_weight
              self.partition_queue.put(child.__copy__())
          except nx.NetworkXException:
              # No arborescence satisfies this child's constraints; skip it.
              pass
          # The next child starts from the prefix of forced-in edges.
          child.partition_dict = prefix.partition_dict.copy()
  def _write_partition(self, partition):
      """
      Annotate ``self.G`` with the edge states described by `partition`.

      Every edge gets its state from `partition`, defaulting to OPEN. Then,
      for any node with exactly one INCLUDED incoming edge whose remaining
      incoming edges are not all EXCLUDED, those remaining edges are forced
      to EXCLUDED. This ensures that if the node is merged during Edmonds'
      algorithm, another of its incoming edges cannot be picked instead.

      Parameters
      ----------
      partition : Partition
          A Partition dataclass describing a partition on the edges of the
          graph.
      """
      key = self.partition_key
      states = partition.partition_dict
      for u, v, data in self.G.edges(data=True):
          data[key] = states.get((u, v), nx.EdgePartition.OPEN)
      nx._clear_cache(self.G)

      included = nx.EdgePartition.INCLUDED
      excluded = nx.EdgePartition.EXCLUDED
      for node in self.G:
          incoming = [
              data for _, _, data in self.G.in_edges(nbunch=node, data=True)
          ]
          labels = [data.get(key) for data in incoming]
          n_included = labels.count(included)
          n_excluded = labels.count(excluded)
          if n_included == 1 and n_excluded != self.G.in_degree(node) - 1:
              for data in incoming:
                  if data.get(key) != included:
                      data[key] = excluded
  def _clear_partition(self, G):
      """
      Remove the partition attribute from every edge of `G`.

      Parameters
      ----------
      G : graph
          The graph to clean. Callers in this class pass either the stored
          graph copy (``self.G``) or an arborescence about to be returned.
      """
      key = self.partition_key
      for _, _, data in G.edges(data=True):
          data.pop(key, None)
      # Invalidate the cache of the graph that was actually mutated; clearing
      # ``self.G`` instead left `G`'s cache untouched whenever G is not self.G.
      nx._clear_cache(G)