shuffle.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
r"""A simple distributed shuffle implementation in Ray.

This utility provides a `simple_shuffle` function that can be used to
redistribute M input partitions into N output partitions. It does this with
a single wave of shuffle map tasks followed by a single wave of shuffle reduce
tasks. Each shuffle map task generates O(N) output objects, and each shuffle
reduce task consumes O(M) input objects, for a total of O(N*M) objects.

To try an example 10GB shuffle (50 partitions of 200 MB each), run:

    $ python -m ray.experimental.shuffle \
        --num-partitions=50 --partition-size=200e6 \
        --object-store-memory=1e9

This will print out some statistics on the shuffle execution such as:

    --- Aggregate object store stats across all nodes ---
    Plasma memory usage 0 MiB, 0 objects, 0.0% full
    Spilled 9487 MiB, 2487 objects, avg write throughput 1023 MiB/s
    Restored 9487 MiB, 2487 objects, avg read throughput 1358 MiB/s
    Objects consumed by Ray tasks: 9537 MiB.

    Shuffled 9536 MiB in 16.579771757125854 seconds
"""
  19. import time
  20. from typing import Any, Callable, Iterable, List, Tuple, Union
  21. import ray
  22. from ray import ObjectRef
  23. from ray.cluster_utils import Cluster
  24. # TODO(ekl) why doesn't TypeVar() deserialize properly in Ray?
  25. # The type produced by the input reader function.
  26. InType = Any
  27. # The type produced by the output writer function.
  28. OutType = Any
  29. # Integer identifying the partition number.
  30. PartitionID = int
  31. class ObjectStoreWriter:
  32. """This class is used to stream shuffle map outputs to the object store.
  33. It can be subclassed to optimize writing (e.g., batching together small
  34. records into larger objects). This will be performance critical if your
  35. input records are small (the example shuffle uses very large records, so
  36. the naive strategy works well).
  37. """
  38. def __init__(self):
  39. self.results = []
  40. def add(self, item: InType) -> None:
  41. """Queue a single item to be written to the object store.
  42. This base implementation immediately writes each given item to the
  43. object store as a standalone object.
  44. """
  45. self.results.append(ray.put(item))
  46. def finish(self) -> List[ObjectRef]:
  47. """Return list of object refs representing written items."""
  48. return self.results
  49. class ObjectStoreWriterNonStreaming(ObjectStoreWriter):
  50. def __init__(self):
  51. self.results = []
  52. def add(self, item: InType) -> None:
  53. self.results.append(item)
  54. def finish(self) -> List[Any]:
  55. return self.results
  56. def round_robin_partitioner(
  57. input_stream: Iterable[InType], num_partitions: int
  58. ) -> Iterable[Tuple[PartitionID, InType]]:
  59. """Round robin partitions items from the input reader.
  60. You can write custom partitioning functions for your use case.
  61. Args:
  62. input_stream: Iterator over items from the input reader.
  63. num_partitions: Number of output partitions.
  64. Yields:
  65. Tuples of (partition id, input item).
  66. """
  67. i = 0
  68. for item in input_stream:
  69. yield (i, item)
  70. i += 1
  71. i %= num_partitions
  72. @ray.remote
  73. class _StatusTracker:
  74. def __init__(self):
  75. self.num_map = 0
  76. self.num_reduce = 0
  77. self.map_refs = []
  78. self.reduce_refs = []
  79. def register_objectrefs(self, map_refs, reduce_refs):
  80. self.map_refs = map_refs
  81. self.reduce_refs = reduce_refs
  82. def get_progress(self):
  83. if self.map_refs:
  84. ready, self.map_refs = ray.wait(
  85. self.map_refs,
  86. timeout=1,
  87. num_returns=len(self.map_refs),
  88. fetch_local=False,
  89. )
  90. self.num_map += len(ready)
  91. elif self.reduce_refs:
  92. ready, self.reduce_refs = ray.wait(
  93. self.reduce_refs,
  94. timeout=1,
  95. num_returns=len(self.reduce_refs),
  96. fetch_local=False,
  97. )
  98. self.num_reduce += len(ready)
  99. return self.num_map, self.num_reduce
  100. def render_progress_bar(tracker, input_num_partitions, output_num_partitions):
  101. from tqdm import tqdm
  102. num_map = 0
  103. num_reduce = 0
  104. map_bar = tqdm(total=input_num_partitions, position=0)
  105. map_bar.set_description("Map Progress.")
  106. reduce_bar = tqdm(total=output_num_partitions, position=1)
  107. reduce_bar.set_description("Reduce Progress.")
  108. while num_map < input_num_partitions or num_reduce < output_num_partitions:
  109. new_num_map, new_num_reduce = ray.get(tracker.get_progress.remote())
  110. map_bar.update(new_num_map - num_map)
  111. reduce_bar.update(new_num_reduce - num_reduce)
  112. num_map = new_num_map
  113. num_reduce = new_num_reduce
  114. time.sleep(0.1)
  115. map_bar.close()
  116. reduce_bar.close()
  117. def simple_shuffle(
  118. *,
  119. input_reader: Callable[[PartitionID], Iterable[InType]],
  120. input_num_partitions: int,
  121. output_num_partitions: int,
  122. output_writer: Callable[[PartitionID, List[Union[ObjectRef, Any]]], OutType],
  123. partitioner: Callable[
  124. [Iterable[InType], int], Iterable[PartitionID]
  125. ] = round_robin_partitioner,
  126. object_store_writer: ObjectStoreWriter = ObjectStoreWriter,
  127. tracker: _StatusTracker = None,
  128. streaming: bool = True,
  129. ) -> List[OutType]:
  130. """Simple distributed shuffle in Ray.
  131. Args:
  132. input_reader: Function that generates the input items for a
  133. partition (e.g., data records).
  134. input_num_partitions: The number of input partitions.
  135. output_num_partitions: The desired number of output partitions.
  136. output_writer: Function that consumes a iterator of items for a
  137. given output partition. It returns a single value that will be
  138. collected across all output partitions.
  139. partitioner: Partitioning function to use. Defaults to round-robin
  140. partitioning of input items.
  141. object_store_writer: Class used to write input items to the
  142. object store in an efficient way. Defaults to a naive
  143. implementation that writes each input record as one object.
  144. tracker: Tracker actor that is used to display the progress bar.
  145. streaming: Whether or not if the shuffle will be streaming.
  146. Returns:
  147. List of outputs from the output writers.
  148. """
  149. @ray.remote(num_returns=output_num_partitions)
  150. def shuffle_map(i: PartitionID) -> List[List[Union[Any, ObjectRef]]]:
  151. writers = [object_store_writer() for _ in range(output_num_partitions)]
  152. for out_i, item in partitioner(input_reader(i), output_num_partitions):
  153. writers[out_i].add(item)
  154. return [c.finish() for c in writers]
  155. @ray.remote
  156. def shuffle_reduce(
  157. i: PartitionID, *mapper_outputs: List[List[Union[Any, ObjectRef]]]
  158. ) -> OutType:
  159. input_objects = []
  160. assert len(mapper_outputs) == input_num_partitions
  161. for obj_refs in mapper_outputs:
  162. for obj_ref in obj_refs:
  163. input_objects.append(obj_ref)
  164. return output_writer(i, input_objects)
  165. shuffle_map_out = [shuffle_map.remote(i) for i in range(input_num_partitions)]
  166. shuffle_reduce_out = [
  167. shuffle_reduce.remote(
  168. j, *[shuffle_map_out[i][j] for i in range(input_num_partitions)]
  169. )
  170. for j in range(output_num_partitions)
  171. ]
  172. if tracker:
  173. tracker.register_objectrefs.remote(
  174. [map_out[0] for map_out in shuffle_map_out], shuffle_reduce_out
  175. )
  176. render_progress_bar(tracker, input_num_partitions, output_num_partitions)
  177. return ray.get(shuffle_reduce_out)
  178. def build_cluster(num_nodes, num_cpus, object_store_memory):
  179. cluster = Cluster()
  180. for _ in range(num_nodes):
  181. cluster.add_node(num_cpus=num_cpus, object_store_memory=object_store_memory)
  182. cluster.wait_for_nodes()
  183. return cluster
  184. def run(
  185. ray_address=None,
  186. object_store_memory=1e9,
  187. num_partitions=5,
  188. partition_size=200e6,
  189. num_nodes=None,
  190. num_cpus=8,
  191. no_streaming=False,
  192. use_wait=False,
  193. tracker=None,
  194. ):
  195. import time
  196. import numpy as np
  197. is_multi_node = num_nodes
  198. if ray_address:
  199. print("Connecting to a existing cluster...")
  200. ray.init(address=ray_address, ignore_reinit_error=True)
  201. elif is_multi_node:
  202. print("Emulating a cluster...")
  203. print(f"Num nodes: {num_nodes}")
  204. print(f"Num CPU per node: {num_cpus}")
  205. print(f"Object store memory per node: {object_store_memory}")
  206. cluster = build_cluster(num_nodes, num_cpus, object_store_memory)
  207. ray.init(address=cluster.address)
  208. else:
  209. print("Start a new cluster...")
  210. ray.init(num_cpus=num_cpus, object_store_memory=object_store_memory)
  211. partition_size = int(partition_size)
  212. num_partitions = num_partitions
  213. rows_per_partition = partition_size // (8 * 2)
  214. if tracker is None:
  215. tracker = _StatusTracker.remote()
  216. use_wait = use_wait
  217. def input_reader(i: PartitionID) -> Iterable[InType]:
  218. for _ in range(num_partitions):
  219. yield np.ones((rows_per_partition // num_partitions, 2), dtype=np.int64)
  220. def output_writer(i: PartitionID, shuffle_inputs: List[ObjectRef]) -> OutType:
  221. total = 0
  222. if not use_wait:
  223. for obj_ref in shuffle_inputs:
  224. arr = ray.get(obj_ref)
  225. total += arr.size * arr.itemsize
  226. else:
  227. while shuffle_inputs:
  228. [ready], shuffle_inputs = ray.wait(shuffle_inputs, num_returns=1)
  229. arr = ray.get(ready)
  230. total += arr.size * arr.itemsize
  231. return total
  232. def output_writer_non_streaming(
  233. i: PartitionID, shuffle_inputs: List[Any]
  234. ) -> OutType:
  235. total = 0
  236. for arr in shuffle_inputs:
  237. total += arr.size * arr.itemsize
  238. return total
  239. if no_streaming:
  240. output_writer_callable = output_writer_non_streaming
  241. object_store_writer = ObjectStoreWriterNonStreaming
  242. else:
  243. object_store_writer = ObjectStoreWriter
  244. output_writer_callable = output_writer
  245. start = time.time()
  246. output_sizes = simple_shuffle(
  247. input_reader=input_reader,
  248. input_num_partitions=num_partitions,
  249. output_num_partitions=num_partitions,
  250. output_writer=output_writer_callable,
  251. object_store_writer=object_store_writer,
  252. tracker=tracker,
  253. )
  254. delta = time.time() - start
  255. time.sleep(0.5)
  256. print()
  257. summary = None
  258. for i in range(5):
  259. try:
  260. summary = ray._private.internal_api.memory_summary(stats_only=True)
  261. except Exception:
  262. time.sleep(1)
  263. pass
  264. if summary:
  265. break
  266. print(summary)
  267. print()
  268. print(
  269. "Shuffled", int(sum(output_sizes) / (1024 * 1024)), "MiB in", delta, "seconds"
  270. )
  271. def main():
  272. import argparse
  273. parser = argparse.ArgumentParser()
  274. parser.add_argument("--ray-address", type=str, default=None)
  275. parser.add_argument("--object-store-memory", type=float, default=1e9)
  276. parser.add_argument("--num-partitions", type=int, default=5)
  277. parser.add_argument("--partition-size", type=float, default=200e6)
  278. parser.add_argument("--num-nodes", type=int, default=None)
  279. parser.add_argument("--num-cpus", type=int, default=8)
  280. parser.add_argument("--no-streaming", action="store_true", default=False)
  281. parser.add_argument("--use-wait", action="store_true", default=False)
  282. args = parser.parse_args()
  283. run(
  284. ray_address=args.ray_address,
  285. object_store_memory=args.object_store_memory,
  286. num_partitions=args.num_partitions,
  287. partition_size=args.partition_size,
  288. num_nodes=args.num_nodes,
  289. num_cpus=args.num_cpus,
  290. no_streaming=args.no_streaming,
  291. use_wait=args.use_wait,
  292. )
  293. if __name__ == "__main__":
  294. main()