segment_tree.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import operator
  2. from typing import Any, Optional
  3. class SegmentTree:
  4. """A Segment Tree data structure.
  5. https://en.wikipedia.org/wiki/Segment_tree
  6. Can be used as regular array, but with two important differences:
  7. a) Setting an item's value is slightly slower. It is O(lg capacity),
  8. instead of O(1).
  9. b) Offers efficient `reduce` operation which reduces the tree's values
  10. over some specified contiguous subsequence of items in the array.
  11. Operation could be e.g. min/max/sum.
  12. The data is stored in a list, where the length is 2 * capacity.
  13. The second half of the list stores the actual values for each index, so if
  14. capacity=8, values are stored at indices 8 to 15. The first half of the
  15. array contains the reduced-values of the different (binary divided)
  16. segments, e.g. (capacity=4):
  17. 0=not used
  18. 1=reduced-value over all elements (array indices 4 to 7).
  19. 2=reduced-value over array indices (4 and 5).
  20. 3=reduced-value over array indices (6 and 7).
  21. 4-7: values of the tree.
  22. NOTE that the values of the tree are accessed by indices starting at 0, so
  23. `tree[0]` accesses `internal_array[4]` in the above example.
  24. """
  25. def __init__(
  26. self, capacity: int, operation: Any, neutral_element: Optional[Any] = None
  27. ):
  28. """Initializes a Segment Tree object.
  29. Args:
  30. capacity: Total size of the array - must be a power of two.
  31. operation: Lambda obj, obj -> obj
  32. The operation for combining elements (eg. sum, max).
  33. Must be a mathematical group together with the set of
  34. possible values for array elements.
  35. neutral_element (Optional[obj]): The neutral element for
  36. `operation`. Use None for automatically finding a value:
  37. max: float("-inf"), min: float("inf"), sum: 0.0.
  38. """
  39. assert (
  40. capacity > 0 and capacity & (capacity - 1) == 0
  41. ), "Capacity must be positive and a power of 2!"
  42. self.capacity = capacity
  43. if neutral_element is None:
  44. neutral_element = (
  45. 0.0
  46. if operation is operator.add
  47. else float("-inf")
  48. if operation is max
  49. else float("inf")
  50. )
  51. self.neutral_element = neutral_element
  52. self.value = [self.neutral_element for _ in range(2 * capacity)]
  53. self.operation = operation
  54. def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
  55. """Applies `self.operation` to subsequence of our values.
  56. Subsequence is contiguous, includes `start` and excludes `end`.
  57. self.operation(
  58. arr[start], operation(arr[start+1], operation(... arr[end])))
  59. Args:
  60. start: Start index to apply reduction to.
  61. end (Optional[int]): End index to apply reduction to (excluded).
  62. Returns:
  63. any: The result of reducing self.operation over the specified
  64. range of `self._value` elements.
  65. """
  66. if end is None:
  67. end = self.capacity
  68. elif end < 0:
  69. end += self.capacity
  70. # Init result with neutral element.
  71. result = self.neutral_element
  72. # Map start/end to our actual index space (second half of array).
  73. start += self.capacity
  74. end += self.capacity
  75. # Example:
  76. # internal-array (first half=sums, second half=actual values):
  77. # 0 1 2 3 | 4 5 6 7
  78. # - 6 1 5 | 1 0 2 3
  79. # tree.sum(0, 3) = 3
  80. # internally: start=4, end=7 -> sum values 1 0 2 = 3.
  81. # Iterate over tree starting in the actual-values (second half)
  82. # section.
  83. # 1) start=4 is even -> do nothing.
  84. # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
  85. # 3) int-divide start and end by 2: start=2, end=3
  86. # 4) start still smaller end -> iterate once more.
  87. # 5) start=2 is even -> do nothing.
  88. # 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=1
  89. # NOTE: This adds the sum of indices 4 and 5 to the result.
  90. # Iterate as long as start != end.
  91. while start < end:
  92. # If start is odd: Add its value to result and move start to
  93. # next even value.
  94. if start & 1:
  95. result = self.operation(result, self.value[start])
  96. start += 1
  97. # If end is odd: Move end to previous even value, then add its
  98. # value to result. NOTE: This takes care of excluding `end` in any
  99. # situation.
  100. if end & 1:
  101. end -= 1
  102. result = self.operation(result, self.value[end])
  103. # Divide both start and end by 2 to make them "jump" into the
  104. # next upper level reduce-index space.
  105. start //= 2
  106. end //= 2
  107. # Then repeat till start == end.
  108. return result
  109. def __setitem__(self, idx: int, val: float) -> None:
  110. """
  111. Inserts/overwrites a value in/into the tree.
  112. Args:
  113. idx: The index to insert to. Must be in [0, `self.capacity`)
  114. val: The value to insert.
  115. """
  116. assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"
  117. # Index of the leaf to insert into (always insert in "second half"
  118. # of the tree, the first half is reserved for already calculated
  119. # reduction-values).
  120. idx += self.capacity
  121. self.value[idx] = val
  122. # Recalculate all affected reduction values (in "first half" of tree).
  123. idx = idx >> 1 # Divide by 2 (faster than division).
  124. while idx >= 1:
  125. update_idx = 2 * idx # calculate only once
  126. # Update the reduction value at the correct "first half" idx.
  127. self.value[idx] = self.operation(
  128. self.value[update_idx], self.value[update_idx + 1]
  129. )
  130. idx = idx >> 1 # Divide by 2 (faster than division).
  131. def __getitem__(self, idx: int) -> Any:
  132. assert 0 <= idx < self.capacity
  133. return self.value[idx + self.capacity]
  134. def get_state(self):
  135. return self.value
  136. def set_state(self, state):
  137. assert len(state) == self.capacity * 2
  138. self.value = state
  139. class SumSegmentTree(SegmentTree):
  140. """A SegmentTree with the reduction `operation`=operator.add."""
  141. def __init__(self, capacity: int):
  142. super(SumSegmentTree, self).__init__(capacity=capacity, operation=operator.add)
  143. def sum(self, start: int = 0, end: Optional[Any] = None) -> Any:
  144. """Returns the sum over a sub-segment of the tree."""
  145. return self.reduce(start, end)
  146. def find_prefixsum_idx(self, prefixsum: float) -> int:
  147. """Finds highest i, for which: sum(arr[0]+..+arr[i - i]) <= prefixsum.
  148. Args:
  149. prefixsum: `prefixsum` upper bound in above constraint.
  150. Returns:
  151. int: Largest possible index (i) satisfying above constraint.
  152. """
  153. assert 0 <= prefixsum <= self.sum() + 1e-5
  154. # Global sum node.
  155. idx = 1
  156. # Edge case when prefixsum can clip into the invalid regions
  157. # https://github.com/ray-project/ray/issues/54284
  158. if prefixsum >= self.value[idx]:
  159. prefixsum = self.value[idx] - 1e-5
  160. # While non-leaf (first half of tree).
  161. while idx < self.capacity:
  162. update_idx = 2 * idx
  163. if self.value[update_idx] > prefixsum:
  164. idx = update_idx
  165. else:
  166. prefixsum -= self.value[update_idx]
  167. idx = update_idx + 1
  168. return idx - self.capacity
  169. class MinSegmentTree(SegmentTree):
  170. def __init__(self, capacity: int):
  171. super(MinSegmentTree, self).__init__(capacity=capacity, operation=min)
  172. def min(self, start: int = 0, end: Optional[Any] = None) -> Any:
  173. """Returns min(arr[start], ..., arr[end])"""
  174. return self.reduce(start, end)