lukes.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """Lukes Algorithm for exact optimal weighted tree partitioning."""
  2. from copy import deepcopy
  3. from functools import lru_cache
  4. from random import choice
  5. import networkx as nx
  6. from networkx.utils import not_implemented_for
# Public API of this module: only the partitioning entry point is exported.
__all__ = ["lukes_partitioning"]
  8. D_EDGE_W = "weight"
  9. D_EDGE_VALUE = 1.0
  10. D_NODE_W = "weight"
  11. D_NODE_VALUE = 1
  12. PKEY = "partitions"
  13. CLUSTER_EVAL_CACHE_SIZE = 2048
  14. def _split_n_from(n, min_size_of_first_part):
  15. # splits j in two parts of which the first is at least
  16. # the second argument
  17. assert n >= min_size_of_first_part
  18. for p1 in range(min_size_of_first_part, n + 1):
  19. yield p1, n - p1
  20. @nx._dispatchable(node_attrs="node_weight", edge_attrs="edge_weight")
  21. def lukes_partitioning(G, max_size, node_weight=None, edge_weight=None):
  22. """Optimal partitioning of a weighted tree using the Lukes algorithm.
  23. This algorithm partitions a connected, acyclic graph featuring integer
  24. node weights and float edge weights. The resulting clusters are such
  25. that the total weight of the nodes in each cluster does not exceed
  26. max_size and that the weight of the edges that are cut by the partition
  27. is minimum. The algorithm is based on [1]_.
  28. Parameters
  29. ----------
  30. G : NetworkX graph
  31. max_size : int
  32. Maximum weight a partition can have in terms of sum of
  33. node_weight for all nodes in the partition
  34. edge_weight : key
  35. Edge data key to use as weight. If None, the weights are all
  36. set to one.
  37. node_weight : key
  38. Node data key to use as weight. If None, the weights are all
  39. set to one. The data must be int.
  40. Returns
  41. -------
  42. partition : list
  43. A list of sets of nodes representing the clusters of the
  44. partition.
  45. Raises
  46. ------
  47. NotATree
  48. If G is not a tree.
  49. TypeError
  50. If any of the values of node_weight is not int.
  51. References
  52. ----------
  53. .. [1] Lukes, J. A. (1974).
  54. "Efficient Algorithm for the Partitioning of Trees."
  55. IBM Journal of Research and Development, 18(3), 217–224.
  56. """
  57. # First sanity check and tree preparation
  58. if not nx.is_tree(G):
  59. raise nx.NotATree("lukes_partitioning works only on trees")
  60. else:
  61. if nx.is_directed(G):
  62. root = [n for n, d in G.in_degree() if d == 0]
  63. assert len(root) == 1
  64. root = root[0]
  65. t_G = deepcopy(G)
  66. else:
  67. root = choice(list(G.nodes))
  68. # this has the desirable side effect of not inheriting attributes
  69. t_G = nx.dfs_tree(G, root)
  70. # Since we do not want to screw up the original graph,
  71. # if we have a blank attribute, we make a deepcopy
  72. if edge_weight is None or node_weight is None:
  73. safe_G = deepcopy(G)
  74. if edge_weight is None:
  75. nx.set_edge_attributes(safe_G, D_EDGE_VALUE, D_EDGE_W)
  76. edge_weight = D_EDGE_W
  77. if node_weight is None:
  78. nx.set_node_attributes(safe_G, D_NODE_VALUE, D_NODE_W)
  79. node_weight = D_NODE_W
  80. else:
  81. safe_G = G
  82. # Second sanity check
  83. # The values of node_weight MUST BE int.
  84. # I cannot see any room for duck typing without incurring serious
  85. # danger of subtle bugs.
  86. all_n_attr = nx.get_node_attributes(safe_G, node_weight).values()
  87. for x in all_n_attr:
  88. if not isinstance(x, int):
  89. raise TypeError(
  90. "lukes_partitioning needs integer "
  91. f"values for node_weight ({node_weight})"
  92. )
  93. # SUBROUTINES -----------------------
  94. # these functions are defined here for two reasons:
  95. # - brevity: we can leverage global "safe_G"
  96. # - caching: signatures are hashable
  97. @not_implemented_for("undirected")
  98. # this is intended to be called only on t_G
  99. def _leaves(gr):
  100. for x in gr.nodes:
  101. if not nx.descendants(gr, x):
  102. yield x
  103. @not_implemented_for("undirected")
  104. def _a_parent_of_leaves_only(gr):
  105. tleaves = set(_leaves(gr))
  106. for n in set(gr.nodes) - tleaves:
  107. if all(x in tleaves for x in nx.descendants(gr, n)):
  108. return n
  109. @lru_cache(CLUSTER_EVAL_CACHE_SIZE)
  110. def _value_of_cluster(cluster):
  111. valid_edges = [e for e in safe_G.edges if e[0] in cluster and e[1] in cluster]
  112. return sum(safe_G.edges[e][edge_weight] for e in valid_edges)
  113. def _value_of_partition(partition):
  114. return sum(_value_of_cluster(frozenset(c)) for c in partition)
  115. @lru_cache(CLUSTER_EVAL_CACHE_SIZE)
  116. def _weight_of_cluster(cluster):
  117. return sum(safe_G.nodes[n][node_weight] for n in cluster)
  118. def _pivot(partition, node):
  119. ccx = [c for c in partition if node in c]
  120. assert len(ccx) == 1
  121. return ccx[0]
  122. def _concatenate_or_merge(partition_1, partition_2, x, i, ref_weight):
  123. ccx = _pivot(partition_1, x)
  124. cci = _pivot(partition_2, i)
  125. merged_xi = ccx.union(cci)
  126. # We first check if we can do the merge.
  127. # If so, we do the actual calculations, otherwise we concatenate
  128. if _weight_of_cluster(frozenset(merged_xi)) <= ref_weight:
  129. cp1 = list(filter(lambda x: x != ccx, partition_1))
  130. cp2 = list(filter(lambda x: x != cci, partition_2))
  131. option_2 = [merged_xi] + cp1 + cp2
  132. return option_2, _value_of_partition(option_2)
  133. else:
  134. option_1 = partition_1 + partition_2
  135. return option_1, _value_of_partition(option_1)
  136. # INITIALIZATION -----------------------
  137. leaves = set(_leaves(t_G))
  138. for lv in leaves:
  139. t_G.nodes[lv][PKEY] = {}
  140. slot = safe_G.nodes[lv][node_weight]
  141. t_G.nodes[lv][PKEY][slot] = [{lv}]
  142. t_G.nodes[lv][PKEY][0] = [{lv}]
  143. for inner in [x for x in t_G.nodes if x not in leaves]:
  144. t_G.nodes[inner][PKEY] = {}
  145. slot = safe_G.nodes[inner][node_weight]
  146. t_G.nodes[inner][PKEY][slot] = [{inner}]
  147. nx._clear_cache(t_G)
  148. # CORE ALGORITHM -----------------------
  149. while True:
  150. x_node = _a_parent_of_leaves_only(t_G)
  151. weight_of_x = safe_G.nodes[x_node][node_weight]
  152. best_value = 0
  153. best_partition = None
  154. bp_buffer = {}
  155. x_descendants = nx.descendants(t_G, x_node)
  156. for i_node in x_descendants:
  157. for j in range(weight_of_x, max_size + 1):
  158. for a, b in _split_n_from(j, weight_of_x):
  159. if (
  160. a not in t_G.nodes[x_node][PKEY]
  161. or b not in t_G.nodes[i_node][PKEY]
  162. ):
  163. # it's not possible to form this particular weight sum
  164. continue
  165. part1 = t_G.nodes[x_node][PKEY][a]
  166. part2 = t_G.nodes[i_node][PKEY][b]
  167. part, value = _concatenate_or_merge(part1, part2, x_node, i_node, j)
  168. if j not in bp_buffer or bp_buffer[j][1] < value:
  169. # we annotate in the buffer the best partition for j
  170. bp_buffer[j] = part, value
  171. # we also keep track of the overall best partition
  172. if best_value <= value:
  173. best_value = value
  174. best_partition = part
  175. # as illustrated in Lukes, once we finished a child, we can
  176. # discharge the partitions we found into the graph
  177. # (the key phrase is make all x == x')
  178. # so that they are used by the subsequent children
  179. for w, (best_part_for_vl, vl) in bp_buffer.items():
  180. t_G.nodes[x_node][PKEY][w] = best_part_for_vl
  181. bp_buffer.clear()
  182. # the absolute best partition for this node
  183. # across all weights has to be stored at 0
  184. t_G.nodes[x_node][PKEY][0] = best_partition
  185. t_G.remove_nodes_from(x_descendants)
  186. if x_node == root:
  187. # the 0-labeled partition of root
  188. # is the optimal one for the whole tree
  189. return t_G.nodes[root][PKEY][0]