utils.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. """
  2. Utility classes and functions for network flow algorithms.
  3. """
  4. from collections import deque
  5. import networkx as nx
  # Public API of this module (what ``from ... import *`` exposes).
  __all__ = [
      # Helper classes.
      "CurrentEdge",
      "Level",
      "GlobalRelabelThreshold",
      # Residual-network helper functions.
      "build_residual_network",
      "detect_unboundedness",
      "build_flow_dict",
  ]
  class CurrentEdge:
      """Mechanism for iterating over out-edges incident to a node in a circular
      manner. StopIteration exception is raised when wraparound occurs.

      ``edges`` is a mapping whose ``items()`` are iterated; the current edge is
      therefore a ``(key, value)`` pair of that mapping. If ``edges`` is empty,
      no current edge is established and ``get``/``move_to_next`` must not be
      called.
      """

      __slots__ = ("_edges", "_it", "_curr")

      def __init__(self, edges):
          self._edges = edges
          # Only position on the first edge when one exists; an empty mapping
          # leaves ``_it`` and ``_curr`` unset.
          if self._edges:
              self._rewind()

      def get(self):
          """Return the current ``(key, value)`` edge pair."""
          return self._curr

      def move_to_next(self):
          """Advance to the next edge.

          Raises
          ------
          StopIteration
              When the last edge has been passed. Before re-raising, the
              cursor is rewound to the first edge so iteration can continue.
          """
          try:
              self._curr = next(self._it)
          except StopIteration:
              self._rewind()
              raise

      def _rewind(self):
          # Restart iteration at the first item of the mapping.
          self._it = iter(self._edges.items())
          self._curr = next(self._it)

      def __eq__(self, other):
          """Compare the current edge and the underlying mapping.

          Returns ``NotImplemented`` for objects that are not ``CurrentEdge``
          instances, so comparisons with unrelated types fall back to Python's
          default handling instead of raising ``AttributeError``.
          """
          if not isinstance(other, CurrentEdge):
              return NotImplemented
          # ``_curr`` may be unset for an empty mapping, hence getattr.
          return (getattr(self, "_curr", None), self._edges) == (
              getattr(other, "_curr", None),
              other._edges,
          )

      # Defining __eq__ already disables hashing implicitly; state it
      # explicitly since instances are mutable cursors.
      __hash__ = None
  class Level:
      """Container holding the active and inactive node sets of one level."""

      __slots__ = ("active", "inactive")

      def __init__(self):
          # Two distinct, initially empty sets (never shared between instances).
          self.active, self.inactive = set(), set()
  class GlobalRelabelThreshold:
      """Accumulate work and report when the global relabeling heuristic is due.

      The threshold is ``(n + m) / freq``. A falsy ``freq`` (such as ``0`` or
      ``None``) makes the threshold infinite, so it is never reached.
      """

      def __init__(self, n, m, freq):
          if not freq:
              self._threshold = float("inf")
          else:
              self._threshold = (n + m) / freq
          self._work = 0

      def add_work(self, work):
          """Add ``work`` units to the running total."""
          self._work += work

      def is_reached(self):
          """Return True once the accumulated work meets the threshold."""
          return self._threshold <= self._work

      def clear_work(self):
          """Reset the accumulated work to zero."""
          self._work = 0
  @nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True)
  def build_residual_network(G, capacity):
      """Build a residual network from a (non-multi) graph.

      Parameters
      ----------
      G : NetworkX graph
          Input graph. MultiGraph and MultiDiGraph are not supported.
      capacity : string
          Name of the edge attribute holding capacities. Edges without this
          attribute are treated as having infinite capacity.

      Returns
      -------
      R : NetworkX DiGraph
          The residual network.

      Raises
      ------
      NetworkXError
          If ``G`` is a multigraph.

      Notes
      -----
      The residual network :samp:`R` from an input graph :samp:`G` has the
      same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair
      of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a
      self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists
      in :samp:`G` with positive capacity.

      For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']`
      is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists
      in :samp:`G` or zero otherwise. For an undirected :samp:`G`, both
      directions receive the edge's capacity. If the capacity is infinite,
      :samp:`R[u][v]['capacity']` will have a high arbitrary finite value
      that does not affect the solution of the problem. This value is stored
      in :samp:`R.graph['inf']`.

      This function sets only the ``'capacity'`` edge attribute and the
      ``'inf'`` graph attribute. It does not set any ``'flow'`` edge attribute
      or ``R.graph['flow_value']``; the flow algorithm that consumes :samp:`R`
      is expected to populate those. By convention, :samp:`R[u][v]['flow']`
      then represents the flow function of :samp:`(u, v)` and satisfies
      :samp:`R[u][v]['flow'] == -R[v][u]['flow']`.
      """
      if G.is_multigraph():
          raise nx.NetworkXError("MultiGraph and MultiDiGraph not supported (yet).")

      R = nx.DiGraph()
      R.__networkx_cache__ = None  # Disable caching
      R.add_nodes_from(G)

      true_inf = float("inf")
      # Extract edges with positive capacities (a missing capacity counts as
      # infinite). Self loops are excluded.
      edge_list = [
          (u, v, attr)
          for u, v, attr in G.edges(data=True)
          if u != v and attr.get(capacity, true_inf) > 0
      ]

      # Simulate infinity with three times the sum of the finite edge capacities
      # or any positive value if the sum is zero. This allows the
      # infinite-capacity edges to be distinguished for unboundedness detection
      # and directly participate in residual capacity calculation. If the maximum
      # flow is finite, these edges cannot appear in the minimum cut and thus
      # guarantee correctness. Since the residual capacity of an
      # infinite-capacity edge is always at least 2/3 of inf, while that of an
      # finite-capacity edge is at most 1/3 of inf, if an operation moves more
      # than 1/3 of inf units of flow to t, there must be an infinite-capacity
      # s-t path in G.
      inf = (
          3
          * sum(
              attr[capacity]
              for _, _, attr in edge_list
              if capacity in attr and attr[capacity] != true_inf
          )
          or 1
      )

      if G.is_directed():
          for u, v, attr in edge_list:
              # Cap infinite (or missing) capacities at the simulated value.
              r = min(attr.get(capacity, inf), inf)
              if not R.has_edge(u, v):
                  # Both (u, v) and (v, u) must be present in the residual
                  # network.
                  R.add_edge(u, v, capacity=r)
                  R.add_edge(v, u, capacity=0)
              else:
                  # The edge (u, v) was added when (v, u) was visited.
                  R[u][v]["capacity"] = r
      else:
          for u, v, attr in edge_list:
              # Add a pair of edges with equal residual capacities.
              r = min(attr.get(capacity, inf), inf)
              R.add_edge(u, v, capacity=r)
              R.add_edge(v, u, capacity=r)

      # Record the value simulating infinity.
      R.graph["inf"] = inf
      return R
  @nx._dispatchable(
      graphs="R",
      preserve_edge_attrs={"R": {"capacity": float("inf")}},
      preserve_graph_attrs=True,
  )
  def detect_unboundedness(R, s, t):
      """Raise if ``t`` is reachable from ``s`` through infinite-capacity edges.

      Runs a breadth-first search from ``s`` that follows only edges whose
      ``'capacity'`` equals the simulated infinity ``R.graph["inf"]``.

      Raises
      ------
      NetworkXUnbounded
          If such a path from ``s`` to ``t`` exists. Otherwise returns None.
      """
      frontier = deque([s])
      visited = {s}
      big = R.graph["inf"]
      while frontier:
          node = frontier.popleft()
          for nbr, data in R[node].items():
              # Skip edges that are not "infinite" or lead to explored nodes.
              if data["capacity"] != big or nbr in visited:
                  continue
              if nbr == t:
                  raise nx.NetworkXUnbounded(
                      "Infinite capacity path, flow unbounded above."
                  )
              visited.add(nbr)
              frontier.append(nbr)
  @nx._dispatchable(graphs={"G": 0, "R": 1}, preserve_edge_attrs={"R": {"flow": None}})
  def build_flow_dict(G, R):
      """Convert the flows stored on residual network ``R`` into a nested dict.

      Every node ``u`` of ``G`` maps to a dict whose keys start as ``u``'s
      neighbors in ``G`` with value 0; any edge ``(u, v)`` of ``R`` with a
      positive ``'flow'`` then sets ``result[u][v]`` to that flow.
      """
      result = {}
      for u in G:
          outgoing = dict.fromkeys(G[u], 0)
          for v, data in R[u].items():
              if data["flow"] > 0:
                  outgoing[v] = data["flow"]
          result[u] = outgoing
      return result