| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- """A simple distributed shuffle implementation in Ray.
- This utility provides a `simple_shuffle` function that can be used to
- redistribute M input partitions into N output partitions. It does this with
- a single wave of shuffle map tasks followed by a single wave of shuffle reduce
- tasks. Each shuffle map task generates O(N) output objects, and each shuffle
- reduce task consumes O(M) input objects, for a total of O(N*M) objects.
- To try an example 10GB shuffle, run:
- $ python -m ray.experimental.shuffle \
- --num-partitions=50 --partition-size=200e6 \
- --object-store-memory=1e9
- This will print out some statistics on the shuffle execution such as:
- --- Aggregate object store stats across all nodes ---
- Plasma memory usage 0 MiB, 0 objects, 0.0% full
- Spilled 9487 MiB, 2487 objects, avg write throughput 1023 MiB/s
- Restored 9487 MiB, 2487 objects, avg read throughput 1358 MiB/s
- Objects consumed by Ray tasks: 9537 MiB.
- Shuffled 9536 MiB in 16.579771757125854 seconds
- """
- import time
- from typing import Any, Callable, Iterable, List, Tuple, Union
- import ray
- from ray import ObjectRef
- from ray.cluster_utils import Cluster
# Type aliases used in the signatures throughout this module. They are plain
# aliases rather than TypeVars, presumably as a workaround for the TODO below.
# TODO(ekl) why doesn't TypeVar() deserialize properly in Ray?

# The type of the items produced by the input reader function.
InType = Any

# The type of the value returned by the output writer function.
OutType = Any

# Integer identifying the partition number.
PartitionID = int
class ObjectStoreWriter:
    """Streams shuffle map outputs into the Ray object store.

    Subclass this to tune how records are written (for example, packing many
    small records into fewer, larger objects). That matters for performance
    when input records are small; the example shuffle in this module uses very
    large records, so the naive one-object-per-item strategy works well.
    """

    def __init__(self):
        # Object refs of everything written so far, in insertion order.
        self.results = []

    def add(self, item: InType) -> None:
        """Write a single item to the object store and remember its ref.

        This base implementation stores every item immediately as its own
        standalone object.
        """
        ref = ray.put(item)
        self.results.append(ref)

    def finish(self) -> List[ObjectRef]:
        """Return the object refs of all items added so far."""
        return self.results
class ObjectStoreWriterNonStreaming(ObjectStoreWriter):
    """Writer variant that buffers the raw items instead of object refs.

    Nothing is put into the object store by this class; `finish` hands back
    the items themselves.
    """

    def __init__(self):
        # The base initializer only creates the empty `results` list.
        super().__init__()

    def add(self, item: InType) -> None:
        """Buffer `item` as-is, without writing it to the object store."""
        self.results.append(item)

    def finish(self) -> List[Any]:
        """Return the buffered items in the order they were added."""
        return self.results
def round_robin_partitioner(
    input_stream: Iterable[InType], num_partitions: int
) -> Iterable[Tuple[PartitionID, InType]]:
    """Assign items from the input reader to partitions in round-robin order.

    Custom partitioning functions with the same signature can be written for
    other use cases.

    Args:
        input_stream: Iterator over items from the input reader.
        num_partitions: Number of output partitions.

    Yields:
        Tuples of (partition id, input item). Partition ids start at 0 and
        wrap around once `num_partitions` ids have been handed out.
    """
    target = 0
    for record in input_stream:
        yield target, record
        # Advance to the next partition, wrapping back around at the end.
        target = (target + 1) % num_partitions
@ray.remote
class _StatusTracker:
    """Actor that counts how many tracked map and reduce refs are ready."""

    def __init__(self):
        # Refs that have not yet been observed as ready.
        self.map_refs = []
        self.reduce_refs = []
        # Running totals of refs observed as ready.
        self.num_map = 0
        self.num_reduce = 0

    def register_objectrefs(self, map_refs, reduce_refs):
        """Replace the tracked map and reduce refs with the given lists."""
        self.map_refs, self.reduce_refs = map_refs, reduce_refs

    def get_progress(self):
        """Poll the outstanding refs and return `(num_map, num_reduce)`.

        Map refs are polled for as long as any remain; reduce refs are only
        polled once the map list is empty. Each call performs at most one
        `ray.wait` with `timeout=1`.
        """

        def poll(pending):
            # Ask for everything that is ready without fetching it locally.
            return ray.wait(
                pending,
                timeout=1,
                num_returns=len(pending),
                fetch_local=False,
            )

        if self.map_refs:
            finished, self.map_refs = poll(self.map_refs)
            self.num_map += len(finished)
        elif self.reduce_refs:
            finished, self.reduce_refs = poll(self.reduce_refs)
            self.num_reduce += len(finished)
        return self.num_map, self.num_reduce
def render_progress_bar(tracker, input_num_partitions, output_num_partitions):
    """Display map and reduce progress bars until both phases are complete.

    Blocks the caller, polling the tracker actor until it reports
    `input_num_partitions` finished map refs and `output_num_partitions`
    finished reduce refs.

    Args:
        tracker: Actor handle exposing a `get_progress` remote method that
            returns `(num_map, num_reduce)`.
        input_num_partitions: Total used for the map progress bar.
        output_num_partitions: Total used for the reduce progress bar.
    """
    # Imported here so tqdm is only needed when a progress bar is rendered.
    from tqdm import tqdm

    num_map = 0
    num_reduce = 0
    map_bar = tqdm(total=input_num_partitions, position=0)
    map_bar.set_description("Map Progress.")
    reduce_bar = tqdm(total=output_num_partitions, position=1)
    reduce_bar.set_description("Reduce Progress.")

    try:
        while num_map < input_num_partitions or num_reduce < output_num_partitions:
            new_num_map, new_num_reduce = ray.get(tracker.get_progress.remote())
            # tqdm.update() takes an increment, not an absolute value.
            map_bar.update(new_num_map - num_map)
            reduce_bar.update(new_num_reduce - num_reduce)
            num_map = new_num_map
            num_reduce = new_num_reduce
            time.sleep(0.1)
    finally:
        # Close the bars even if polling the tracker raises, so the terminal
        # is not left with dangling progress bars.
        map_bar.close()
        reduce_bar.close()
def simple_shuffle(
    *,
    input_reader: Callable[[PartitionID], Iterable[InType]],
    input_num_partitions: int,
    output_num_partitions: int,
    output_writer: Callable[[PartitionID, List[Union[ObjectRef, Any]]], OutType],
    partitioner: Callable[
        [Iterable[InType], int], Iterable[Tuple[PartitionID, InType]]
    ] = round_robin_partitioner,
    object_store_writer: Callable[[], ObjectStoreWriter] = ObjectStoreWriter,
    tracker: _StatusTracker = None,
    streaming: bool = True,
) -> List[OutType]:
    """Simple distributed shuffle in Ray.

    Launches one map task per input partition and one reduce task per output
    partition. Reduce task `j` receives the `j`-th output of every map task.

    Args:
        input_reader: Function that generates the input items for a
            partition (e.g., data records).
        input_num_partitions: The number of input partitions.
        output_num_partitions: The desired number of output partitions.
        output_writer: Function that consumes the list of items (or object
            refs, depending on `object_store_writer`) for a given output
            partition. It returns a single value that will be collected
            across all output partitions.
        partitioner: Partitioning function to use. It must yield
            (partition id, item) tuples. Defaults to round-robin
            partitioning of input items.
        object_store_writer: Zero-argument factory (typically a class) for
            the writer used to write input items to the object store.
            Defaults to a naive implementation that writes each input record
            as one object.
        tracker: Tracker actor that is used to display the progress bar. No
            progress bar is shown when this is None.
        streaming: Currently not used by this function; whether items are
            streamed through the object store is determined by
            `object_store_writer`. Kept for interface compatibility.

    Returns:
        List of outputs from the output writers, one per output partition.
    """
    # With num_returns == 1, Ray's `.remote()` returns a bare ObjectRef rather
    # than a list of refs, and the task's return value is stored as that one
    # object. Both the map task and the driver must account for this.
    single_output = output_num_partitions == 1

    @ray.remote(num_returns=output_num_partitions)
    def shuffle_map(i: PartitionID) -> List[List[Union[Any, ObjectRef]]]:
        writers = [object_store_writer() for _ in range(output_num_partitions)]
        for out_i, item in partitioner(input_reader(i), output_num_partitions):
            writers[out_i].add(item)
        outputs = [c.finish() for c in writers]
        # Return the lone partition's list directly so that the single stored
        # object has the same shape as each object in the multi-return case.
        return outputs[0] if single_output else outputs

    @ray.remote
    def shuffle_reduce(
        i: PartitionID, *mapper_outputs: List[List[Union[Any, ObjectRef]]]
    ) -> OutType:
        # Internal invariant: one argument per map task.
        assert len(mapper_outputs) == input_num_partitions
        input_objects = []
        for obj_refs in mapper_outputs:
            input_objects.extend(obj_refs)
        return output_writer(i, input_objects)

    # shuffle_map_out[i][j] is the ref to map task i's output for partition j.
    shuffle_map_out = []
    for i in range(input_num_partitions):
        map_out = shuffle_map.remote(i)
        shuffle_map_out.append([map_out] if single_output else map_out)

    shuffle_reduce_out = [
        shuffle_reduce.remote(
            j, *[shuffle_map_out[i][j] for i in range(input_num_partitions)]
        )
        for j in range(output_num_partitions)
    ]

    if tracker:
        # Only the first return of each map task is tracked; that gives one
        # ref per map task for the progress count.
        tracker.register_objectrefs.remote(
            [map_out[0] for map_out in shuffle_map_out], shuffle_reduce_out
        )
        render_progress_bar(tracker, input_num_partitions, output_num_partitions)

    return ray.get(shuffle_reduce_out)
def build_cluster(num_nodes, num_cpus, object_store_memory):
    """Build a `Cluster` made of `num_nodes` identically configured nodes.

    Args:
        num_nodes: Number of nodes to add to the cluster.
        num_cpus: `num_cpus` value passed to every node.
        object_store_memory: `object_store_memory` value passed to every node.

    Returns:
        The `Cluster`, after `wait_for_nodes()` has been called on it.
    """
    node_options = {
        "num_cpus": num_cpus,
        "object_store_memory": object_store_memory,
    }
    cluster = Cluster()
    for _ in range(num_nodes):
        cluster.add_node(**node_options)
    cluster.wait_for_nodes()
    return cluster
def run(
    ray_address=None,
    object_store_memory=1e9,
    num_partitions=5,
    partition_size=200e6,
    num_nodes=None,
    num_cpus=8,
    no_streaming=False,
    use_wait=False,
    tracker=None,
):
    """Run the example shuffle and print timing and object store statistics.

    Initializes Ray (connecting to `ray_address`, emulating a multi-node
    cluster, or starting a fresh local instance), shuffles `num_partitions`
    input partitions into `num_partitions` output partitions, then prints a
    memory summary and the shuffled volume.

    Args:
        ray_address: Address of an existing cluster to connect to. Takes
            precedence over `num_nodes` when set.
        object_store_memory: Object store memory, per node when a cluster is
            emulated. Not used when connecting via `ray_address`.
        num_partitions: Number of input and output partitions.
        partition_size: Size of each input partition in bytes.
        num_nodes: If truthy, emulate a cluster with this many nodes.
        num_cpus: Number of CPUs, per node when a cluster is emulated. Not
            used when connecting via `ray_address`.
        no_streaming: If True, map tasks return the raw items instead of
            putting each one into the object store individually.
        use_wait: If True, the streaming output writer consumes inputs in
            completion order via `ray.wait` instead of in list order.
        tracker: Optional tracker actor for the progress bar. One is created
            when this is None.
    """
    import numpy as np

    if ray_address:
        print("Connecting to a existing cluster...")
        ray.init(address=ray_address, ignore_reinit_error=True)
    elif num_nodes:
        print("Emulating a cluster...")
        print(f"Num nodes: {num_nodes}")
        print(f"Num CPU per node: {num_cpus}")
        print(f"Object store memory per node: {object_store_memory}")
        cluster = build_cluster(num_nodes, num_cpus, object_store_memory)
        ray.init(address=cluster.address)
    else:
        print("Start a new cluster...")
        ray.init(num_cpus=num_cpus, object_store_memory=object_store_memory)

    partition_size = int(partition_size)
    # Each row holds two int64 values, i.e. 8 * 2 bytes.
    rows_per_partition = partition_size // (8 * 2)
    if tracker is None:
        tracker = _StatusTracker.remote()

    def input_reader(i: PartitionID) -> Iterable[InType]:
        # Emit one block per output partition so the blocks of an input
        # partition add up to roughly `partition_size` bytes.
        for _ in range(num_partitions):
            yield np.ones((rows_per_partition // num_partitions, 2), dtype=np.int64)

    def output_writer(i: PartitionID, shuffle_inputs: List[ObjectRef]) -> OutType:
        # Streaming variant: inputs are object refs that must be fetched.
        total = 0
        if not use_wait:
            for obj_ref in shuffle_inputs:
                arr = ray.get(obj_ref)
                total += arr.size * arr.itemsize
        else:
            # Consume inputs in the order they become ready.
            while shuffle_inputs:
                [ready], shuffle_inputs = ray.wait(shuffle_inputs, num_returns=1)
                arr = ray.get(ready)
                total += arr.size * arr.itemsize
        return total

    def output_writer_non_streaming(
        i: PartitionID, shuffle_inputs: List[Any]
    ) -> OutType:
        # Non-streaming variant: inputs are already the arrays themselves.
        total = 0
        for arr in shuffle_inputs:
            total += arr.size * arr.itemsize
        return total

    if no_streaming:
        output_writer_callable = output_writer_non_streaming
        object_store_writer = ObjectStoreWriterNonStreaming
    else:
        object_store_writer = ObjectStoreWriter
        output_writer_callable = output_writer

    start = time.time()
    output_sizes = simple_shuffle(
        input_reader=input_reader,
        input_num_partitions=num_partitions,
        output_num_partitions=num_partitions,
        output_writer=output_writer_callable,
        object_store_writer=object_store_writer,
        tracker=tracker,
        streaming=not no_streaming,
    )
    delta = time.time() - start

    time.sleep(0.5)
    print()

    # Best effort: try a few times to fetch the memory summary and print
    # whatever was obtained (possibly None).
    summary = None
    for _ in range(5):
        try:
            summary = ray._private.internal_api.memory_summary(stats_only=True)
        except Exception:
            time.sleep(1)
        if summary:
            break
    print(summary)
    print()
    print(
        "Shuffled", int(sum(output_sizes) / (1024 * 1024)), "MiB in", delta, "seconds"
    )
def main():
    """Parse the command-line flags and run the example shuffle with them."""
    import argparse

    # (flag, type, default) for every option that takes a value.
    valued_options = (
        ("--ray-address", str, None),
        ("--object-store-memory", float, 1e9),
        ("--num-partitions", int, 5),
        ("--partition-size", float, 200e6),
        ("--num-nodes", int, None),
        ("--num-cpus", int, 8),
    )
    # Boolean switches that are off unless given on the command line.
    switches = ("--no-streaming", "--use-wait")

    parser = argparse.ArgumentParser()
    for flag, value_type, default in valued_options:
        parser.add_argument(flag, type=value_type, default=default)
    for flag in switches:
        parser.add_argument(flag, action="store_true", default=False)
    args = parser.parse_args()

    run(
        ray_address=args.ray_address,
        object_store_memory=args.object_store_memory,
        num_partitions=args.num_partitions,
        partition_size=args.partition_size,
        num_nodes=args.num_nodes,
        num_cpus=args.num_cpus,
        no_streaming=args.no_streaming,
        use_wait=args.use_wait,
    )
# Run the example shuffle only when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|