# _disjoint_set.py (6.0 KB)
"""
Disjoint set (union-find) data structure.

Provides :class:`DisjointSet`, which maintains a partition of hashable
elements into non-overlapping subsets and supports merging subsets and
querying whether two elements belong to the same subset.
"""
  4. class DisjointSet:
  5. """ Disjoint set data structure for incremental connectivity queries.
  6. .. versionadded:: 1.6.0
  7. Attributes
  8. ----------
  9. n_subsets : int
  10. The number of subsets.
  11. Methods
  12. -------
  13. add
  14. merge
  15. connected
  16. subset
  17. subset_size
  18. subsets
  19. __getitem__
  20. Notes
  21. -----
  22. This class implements the disjoint set [1]_, also known as the *union-find*
  23. or *merge-find* data structure. The *find* operation (implemented in
  24. `__getitem__`) implements the *path halving* variant. The *merge* method
  25. implements the *merge by size* variant.
  26. References
  27. ----------
  28. .. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure
  29. Examples
  30. --------
  31. >>> from scipy.cluster.hierarchy import DisjointSet
  32. Initialize a disjoint set:
  33. >>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])
  34. Merge some subsets:
  35. >>> disjoint_set.merge(1, 2)
  36. True
  37. >>> disjoint_set.merge(3, 'a')
  38. True
  39. >>> disjoint_set.merge('a', 'b')
  40. True
  41. >>> disjoint_set.merge('b', 'b')
  42. False
  43. Find root elements:
  44. >>> disjoint_set[2]
  45. 1
  46. >>> disjoint_set['b']
  47. 3
  48. Test connectivity:
  49. >>> disjoint_set.connected(1, 2)
  50. True
  51. >>> disjoint_set.connected(1, 'b')
  52. False
  53. List elements in disjoint set:
  54. >>> list(disjoint_set)
  55. [1, 2, 3, 'a', 'b']
  56. Get the subset containing 'a':
  57. >>> disjoint_set.subset('a')
  58. {'a', 3, 'b'}
  59. Get the size of the subset containing 'a' (without actually instantiating
  60. the subset):
  61. >>> disjoint_set.subset_size('a')
  62. 3
  63. Get all subsets in the disjoint set:
  64. >>> disjoint_set.subsets()
  65. [{1, 2}, {'a', 3, 'b'}]
  66. """
  67. def __init__(self, elements=None):
  68. self.n_subsets = 0
  69. self._sizes = {}
  70. self._parents = {}
  71. # _nbrs is a circular linked list which links connected elements.
  72. self._nbrs = {}
  73. # _indices tracks the element insertion order in `__iter__`.
  74. self._indices = {}
  75. if elements is not None:
  76. for x in elements:
  77. self.add(x)
  78. def __iter__(self):
  79. """Returns an iterator of the elements in the disjoint set.
  80. Elements are ordered by insertion order.
  81. """
  82. return iter(self._indices)
  83. def __len__(self):
  84. return len(self._indices)
  85. def __contains__(self, x):
  86. return x in self._indices
  87. def __getitem__(self, x):
  88. """Find the root element of `x`.
  89. Parameters
  90. ----------
  91. x : hashable object
  92. Input element.
  93. Returns
  94. -------
  95. root : hashable object
  96. Root element of `x`.
  97. """
  98. if x not in self._indices:
  99. raise KeyError(x)
  100. # find by "path halving"
  101. parents = self._parents
  102. while self._indices[x] != self._indices[parents[x]]:
  103. parents[x] = parents[parents[x]]
  104. x = parents[x]
  105. return x
  106. def add(self, x):
  107. """Add element `x` to disjoint set
  108. """
  109. if x in self._indices:
  110. return
  111. self._sizes[x] = 1
  112. self._parents[x] = x
  113. self._nbrs[x] = x
  114. self._indices[x] = len(self._indices)
  115. self.n_subsets += 1
  116. def merge(self, x, y):
  117. """Merge the subsets of `x` and `y`.
  118. The smaller subset (the child) is merged into the larger subset (the
  119. parent). If the subsets are of equal size, the root element which was
  120. first inserted into the disjoint set is selected as the parent.
  121. Parameters
  122. ----------
  123. x, y : hashable object
  124. Elements to merge.
  125. Returns
  126. -------
  127. merged : bool
  128. True if `x` and `y` were in disjoint sets, False otherwise.
  129. """
  130. xr = self[x]
  131. yr = self[y]
  132. if self._indices[xr] == self._indices[yr]:
  133. return False
  134. sizes = self._sizes
  135. if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
  136. xr, yr = yr, xr
  137. self._parents[yr] = xr
  138. self._sizes[xr] += self._sizes[yr]
  139. self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
  140. self.n_subsets -= 1
  141. return True
  142. def connected(self, x, y):
  143. """Test whether `x` and `y` are in the same subset.
  144. Parameters
  145. ----------
  146. x, y : hashable object
  147. Elements to test.
  148. Returns
  149. -------
  150. result : bool
  151. True if `x` and `y` are in the same set, False otherwise.
  152. """
  153. return self._indices[self[x]] == self._indices[self[y]]
  154. def subset(self, x):
  155. """Get the subset containing `x`.
  156. Parameters
  157. ----------
  158. x : hashable object
  159. Input element.
  160. Returns
  161. -------
  162. result : set
  163. Subset containing `x`.
  164. """
  165. if x not in self._indices:
  166. raise KeyError(x)
  167. result = [x]
  168. nxt = self._nbrs[x]
  169. while self._indices[nxt] != self._indices[x]:
  170. result.append(nxt)
  171. nxt = self._nbrs[nxt]
  172. return set(result)
  173. def subset_size(self, x):
  174. """Get the size of the subset containing `x`.
  175. Note that this method is faster than ``len(self.subset(x))`` because
  176. the size is directly read off an internal field, without the need to
  177. instantiate the full subset.
  178. Parameters
  179. ----------
  180. x : hashable object
  181. Input element.
  182. Returns
  183. -------
  184. result : int
  185. Size of the subset containing `x`.
  186. """
  187. return self._sizes[self[x]]
  188. def subsets(self):
  189. """Get all the subsets in the disjoint set.
  190. Returns
  191. -------
  192. result : list
  193. Subsets in the disjoint set.
  194. """
  195. result = []
  196. visited = set()
  197. for x in self:
  198. if x not in visited:
  199. xset = self.subset(x)
  200. visited.update(xset)
  201. result.append(xset)
  202. return result