"""A simple distributed shuffle implementation in Ray. This utility provides a `simple_shuffle` function that can be used to redistribute M input partitions into N output partitions. It does this with a single wave of shuffle map tasks followed by a single wave of shuffle reduce tasks. Each shuffle map task generates O(N) output objects, and each shuffle reduce task consumes O(M) input objects, for a total of O(N*M) objects. To try an example 10GB shuffle, run: $ python -m ray.experimental.shuffle \ --num-partitions=50 --partition-size=200e6 \ --object-store-memory=1e9 This will print out some statistics on the shuffle execution such as: --- Aggregate object store stats across all nodes --- Plasma memory usage 0 MiB, 0 objects, 0.0% full Spilled 9487 MiB, 2487 objects, avg write throughput 1023 MiB/s Restored 9487 MiB, 2487 objects, avg read throughput 1358 MiB/s Objects consumed by Ray tasks: 9537 MiB. Shuffled 9536 MiB in 16.579771757125854 seconds """ import time from typing import Any, Callable, Iterable, List, Tuple, Union import ray from ray import ObjectRef from ray.cluster_utils import Cluster # TODO(ekl) why doesn't TypeVar() deserialize properly in Ray? # The type produced by the input reader function. InType = Any # The type produced by the output writer function. OutType = Any # Integer identifying the partition number. PartitionID = int class ObjectStoreWriter: """This class is used to stream shuffle map outputs to the object store. It can be subclassed to optimize writing (e.g., batching together small records into larger objects). This will be performance critical if your input records are small (the example shuffle uses very large records, so the naive strategy works well). """ def __init__(self): self.results = [] def add(self, item: InType) -> None: """Queue a single item to be written to the object store. This base implementation immediately writes each given item to the object store as a standalone object. """ self.results.append(ray.put(item)) def finish(self) -> List[ObjectRef]: """Return list of object refs representing written items.""" return self.results class ObjectStoreWriterNonStreaming(ObjectStoreWriter): def __init__(self): self.results = [] def add(self, item: InType) -> None: self.results.append(item) def finish(self) -> List[Any]: return self.results def round_robin_partitioner( input_stream: Iterable[InType], num_partitions: int ) -> Iterable[Tuple[PartitionID, InType]]: """Round robin partitions items from the input reader. You can write custom partitioning functions for your use case. Args: input_stream: Iterator over items from the input reader. num_partitions: Number of output partitions. Yields: Tuples of (partition id, input item). """ i = 0 for item in input_stream: yield (i, item) i += 1 i %= num_partitions @ray.remote class _StatusTracker: def __init__(self): self.num_map = 0 self.num_reduce = 0 self.map_refs = [] self.reduce_refs = [] def register_objectrefs(self, map_refs, reduce_refs): self.map_refs = map_refs self.reduce_refs = reduce_refs def get_progress(self): if self.map_refs: ready, self.map_refs = ray.wait( self.map_refs, timeout=1, num_returns=len(self.map_refs), fetch_local=False, ) self.num_map += len(ready) elif self.reduce_refs: ready, self.reduce_refs = ray.wait( self.reduce_refs, timeout=1, num_returns=len(self.reduce_refs), fetch_local=False, ) self.num_reduce += len(ready) return self.num_map, self.num_reduce def render_progress_bar(tracker, input_num_partitions, output_num_partitions): from tqdm import tqdm num_map = 0 num_reduce = 0 map_bar = tqdm(total=input_num_partitions, position=0) map_bar.set_description("Map Progress.") reduce_bar = tqdm(total=output_num_partitions, position=1) reduce_bar.set_description("Reduce Progress.") while num_map < input_num_partitions or num_reduce < output_num_partitions: new_num_map, new_num_reduce = ray.get(tracker.get_progress.remote()) map_bar.update(new_num_map - num_map) reduce_bar.update(new_num_reduce - num_reduce) num_map = new_num_map num_reduce = new_num_reduce time.sleep(0.1) map_bar.close() reduce_bar.close() def simple_shuffle( *, input_reader: Callable[[PartitionID], Iterable[InType]], input_num_partitions: int, output_num_partitions: int, output_writer: Callable[[PartitionID, List[Union[ObjectRef, Any]]], OutType], partitioner: Callable[ [Iterable[InType], int], Iterable[PartitionID] ] = round_robin_partitioner, object_store_writer: ObjectStoreWriter = ObjectStoreWriter, tracker: _StatusTracker = None, streaming: bool = True, ) -> List[OutType]: """Simple distributed shuffle in Ray. Args: input_reader: Function that generates the input items for a partition (e.g., data records). input_num_partitions: The number of input partitions. output_num_partitions: The desired number of output partitions. output_writer: Function that consumes a iterator of items for a given output partition. It returns a single value that will be collected across all output partitions. partitioner: Partitioning function to use. Defaults to round-robin partitioning of input items. object_store_writer: Class used to write input items to the object store in an efficient way. Defaults to a naive implementation that writes each input record as one object. tracker: Tracker actor that is used to display the progress bar. streaming: Whether or not if the shuffle will be streaming. Returns: List of outputs from the output writers. """ @ray.remote(num_returns=output_num_partitions) def shuffle_map(i: PartitionID) -> List[List[Union[Any, ObjectRef]]]: writers = [object_store_writer() for _ in range(output_num_partitions)] for out_i, item in partitioner(input_reader(i), output_num_partitions): writers[out_i].add(item) return [c.finish() for c in writers] @ray.remote def shuffle_reduce( i: PartitionID, *mapper_outputs: List[List[Union[Any, ObjectRef]]] ) -> OutType: input_objects = [] assert len(mapper_outputs) == input_num_partitions for obj_refs in mapper_outputs: for obj_ref in obj_refs: input_objects.append(obj_ref) return output_writer(i, input_objects) shuffle_map_out = [shuffle_map.remote(i) for i in range(input_num_partitions)] shuffle_reduce_out = [ shuffle_reduce.remote( j, *[shuffle_map_out[i][j] for i in range(input_num_partitions)] ) for j in range(output_num_partitions) ] if tracker: tracker.register_objectrefs.remote( [map_out[0] for map_out in shuffle_map_out], shuffle_reduce_out ) render_progress_bar(tracker, input_num_partitions, output_num_partitions) return ray.get(shuffle_reduce_out) def build_cluster(num_nodes, num_cpus, object_store_memory): cluster = Cluster() for _ in range(num_nodes): cluster.add_node(num_cpus=num_cpus, object_store_memory=object_store_memory) cluster.wait_for_nodes() return cluster def run( ray_address=None, object_store_memory=1e9, num_partitions=5, partition_size=200e6, num_nodes=None, num_cpus=8, no_streaming=False, use_wait=False, tracker=None, ): import time import numpy as np is_multi_node = num_nodes if ray_address: print("Connecting to a existing cluster...") ray.init(address=ray_address, ignore_reinit_error=True) elif is_multi_node: print("Emulating a cluster...") print(f"Num nodes: {num_nodes}") print(f"Num CPU per node: {num_cpus}") print(f"Object store memory per node: {object_store_memory}") cluster = build_cluster(num_nodes, num_cpus, object_store_memory) ray.init(address=cluster.address) else: print("Start a new cluster...") ray.init(num_cpus=num_cpus, object_store_memory=object_store_memory) partition_size = int(partition_size) num_partitions = num_partitions rows_per_partition = partition_size // (8 * 2) if tracker is None: tracker = _StatusTracker.remote() use_wait = use_wait def input_reader(i: PartitionID) -> Iterable[InType]: for _ in range(num_partitions): yield np.ones((rows_per_partition // num_partitions, 2), dtype=np.int64) def output_writer(i: PartitionID, shuffle_inputs: List[ObjectRef]) -> OutType: total = 0 if not use_wait: for obj_ref in shuffle_inputs: arr = ray.get(obj_ref) total += arr.size * arr.itemsize else: while shuffle_inputs: [ready], shuffle_inputs = ray.wait(shuffle_inputs, num_returns=1) arr = ray.get(ready) total += arr.size * arr.itemsize return total def output_writer_non_streaming( i: PartitionID, shuffle_inputs: List[Any] ) -> OutType: total = 0 for arr in shuffle_inputs: total += arr.size * arr.itemsize return total if no_streaming: output_writer_callable = output_writer_non_streaming object_store_writer = ObjectStoreWriterNonStreaming else: object_store_writer = ObjectStoreWriter output_writer_callable = output_writer start = time.time() output_sizes = simple_shuffle( input_reader=input_reader, input_num_partitions=num_partitions, output_num_partitions=num_partitions, output_writer=output_writer_callable, object_store_writer=object_store_writer, tracker=tracker, ) delta = time.time() - start time.sleep(0.5) print() summary = None for i in range(5): try: summary = ray._private.internal_api.memory_summary(stats_only=True) except Exception: time.sleep(1) pass if summary: break print(summary) print() print( "Shuffled", int(sum(output_sizes) / (1024 * 1024)), "MiB in", delta, "seconds" ) def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--ray-address", type=str, default=None) parser.add_argument("--object-store-memory", type=float, default=1e9) parser.add_argument("--num-partitions", type=int, default=5) parser.add_argument("--partition-size", type=float, default=200e6) parser.add_argument("--num-nodes", type=int, default=None) parser.add_argument("--num-cpus", type=int, default=8) parser.add_argument("--no-streaming", action="store_true", default=False) parser.add_argument("--use-wait", action="store_true", default=False) args = parser.parse_args() run( ray_address=args.ray_address, object_store_memory=args.object_store_memory, num_partitions=args.num_partitions, partition_size=args.partition_size, num_nodes=args.num_nodes, num_cpus=args.num_cpus, no_streaming=args.no_streaming, use_wait=args.use_wait, ) if __name__ == "__main__": main()