| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- import bisect
- import logging
- import random
- import time
- from collections import defaultdict
- from typing import TYPE_CHECKING, Any, List, Optional
- import numpy as np
- import ray
- from ray.data._internal.execution.interfaces.ref_bundle import (
- _ref_bundles_iterator_to_block_refs_list,
- )
- from ray.data._internal.remote_fn import cached_remote_fn
- from ray.data.block import BlockAccessor
- from ray.data.context import DataContext
- from ray.types import ObjectRef
- from ray.util.annotations import PublicAPI
- try:
- import pyarrow as pa
- except ImportError:
- pa = None
- if TYPE_CHECKING:
- from ray.data.dataset import Dataset
- logger = logging.getLogger(__name__)
- @PublicAPI(stability="alpha")
- class RandomAccessDataset:
- """A class that provides distributed, random access to a Dataset.
- See: ``Dataset.to_random_access_dataset()``.
- """
- def __init__(
- self,
- ds: "Dataset",
- key: str,
- num_workers: int,
- ):
- """Construct a RandomAccessDataset (internal API).
- The constructor is a private API. Use ``ds.to_random_access_dataset()``
- to construct a RandomAccessDataset.
- """
- schema = ds.schema(fetch_if_missing=True)
- if schema is None or isinstance(schema, type):
- raise ValueError("RandomAccessDataset only supports Arrow-format blocks.")
- start = time.perf_counter()
- logger.info("[setup] Indexing dataset by sort key.")
- sorted_ds = ds.sort(key)
- get_bounds = cached_remote_fn(_get_bounds)
- bundles = sorted_ds.iter_internal_ref_bundles()
- blocks = _ref_bundles_iterator_to_block_refs_list(bundles)
- logger.info("[setup] Computing block range bounds.")
- bounds = ray.get([get_bounds.remote(b, key) for b in blocks])
- self._non_empty_blocks = []
- self._lower_bound = None
- self._upper_bounds = []
- for i, b in enumerate(bounds):
- if b:
- self._non_empty_blocks.append(blocks[i])
- if self._lower_bound is None:
- self._lower_bound = b[0]
- self._upper_bounds.append(b[1])
- logger.info("[setup] Creating {} random access workers.".format(num_workers))
- ctx = DataContext.get_current()
- scheduling_strategy = ctx.scheduling_strategy
- self._workers = [
- _RandomAccessWorker.options(scheduling_strategy=scheduling_strategy).remote(
- key
- )
- for _ in range(num_workers)
- ]
- (
- self._block_to_workers_map,
- self._worker_to_blocks_map,
- ) = self._compute_block_to_worker_assignments()
- logger.info(
- "[setup] Worker to blocks assignment: {}".format(self._worker_to_blocks_map)
- )
- ray.get(
- [
- w.assign_blocks.remote(
- {
- i: self._non_empty_blocks[i]
- for i in self._worker_to_blocks_map[w]
- }
- )
- for w in self._workers
- ]
- )
- logger.info("[setup] Finished assigning blocks to workers.")
- self._build_time = time.perf_counter() - start
- def _compute_block_to_worker_assignments(self):
- # Return values.
- block_to_workers: dict[int, List["ray.ActorHandle"]] = defaultdict(list)
- worker_to_blocks: dict["ray.ActorHandle", List[int]] = defaultdict(list)
- # Aux data structures.
- loc_to_workers: dict[str, List["ray.ActorHandle"]] = defaultdict(list)
- locs = ray.get([w.ping.remote() for w in self._workers])
- for i, loc in enumerate(locs):
- loc_to_workers[loc].append(self._workers[i])
- block_locs = ray.experimental.get_object_locations(self._non_empty_blocks)
- # First, try to assign all blocks to all workers at its location.
- for block_idx, block in enumerate(self._non_empty_blocks):
- block_info = block_locs[block]
- locs = block_info.get("node_ids", [])
- for loc in locs:
- for worker in loc_to_workers[loc]:
- block_to_workers[block_idx].append(worker)
- worker_to_blocks[worker].append(block_idx)
- # Randomly assign any leftover blocks to at least one worker.
- # TODO: the load balancing here could be improved.
- for block_idx, block in enumerate(self._non_empty_blocks):
- if len(block_to_workers[block_idx]) == 0:
- worker = random.choice(self._workers)
- block_to_workers[block_idx].append(worker)
- worker_to_blocks[worker].append(block_idx)
- return block_to_workers, worker_to_blocks
- def get_async(self, key: Any) -> ObjectRef[Any]:
- """Asynchronously finds the record for a single key.
- Args:
- key: The key of the record to find.
- Returns:
- ObjectRef containing the record (in pydict form), or None if not found.
- """
- block_index = self._find_le(key)
- if block_index is None:
- return ray.put(None)
- return self._worker_for(block_index).get.remote(block_index, key)
- def multiget(self, keys: List[Any]) -> List[Optional[Any]]:
- """Synchronously find the records for a list of keys.
- Args:
- keys: List of keys to find the records for.
- Returns:
- List of found records (in pydict form), or None for missing records.
- """
- batches = defaultdict(list)
- for k in keys:
- batches[self._find_le(k)].append(k)
- futures = {}
- for index, keybatch in batches.items():
- if index is None:
- continue
- fut = self._worker_for(index).multiget.remote(
- [index] * len(keybatch), keybatch
- )
- futures[index] = fut
- results = {}
- for i, fut in futures.items():
- keybatch = batches[i]
- values = ray.get(fut)
- for k, v in zip(keybatch, values):
- results[k] = v
- return [results.get(k) for k in keys]
- def stats(self) -> str:
- """Returns a string containing access timing information."""
- stats = ray.get([w.stats.remote() for w in self._workers])
- total_time = sum(s["total_time"] for s in stats)
- accesses = [s["num_accesses"] for s in stats]
- blocks = [s["num_blocks"] for s in stats]
- msg = "RandomAccessDataset:\n"
- msg += "- Build time: {}s\n".format(round(self._build_time, 2))
- msg += "- Num workers: {}\n".format(len(stats))
- msg += "- Blocks per worker: {} min, {} max, {} mean\n".format(
- min(blocks), max(blocks), int(sum(blocks) / len(blocks))
- )
- msg += "- Accesses per worker: {} min, {} max, {} mean\n".format(
- min(accesses), max(accesses), int(sum(accesses) / len(accesses))
- )
- msg += "- Mean access time: {}us\n".format(
- int(total_time / (1 + sum(accesses)) * 1e6)
- )
- return msg
- def _worker_for(self, block_index: int):
- return random.choice(self._block_to_workers_map[block_index])
- def _find_le(self, x: Any) -> int:
- i = bisect.bisect_left(self._upper_bounds, x)
- if i >= len(self._upper_bounds) or x < self._lower_bound:
- return None
- return i
- @ray.remote(num_cpus=0)
- class _RandomAccessWorker:
- def __init__(self, key_field):
- self.blocks = None
- self.key_field = key_field
- self.num_accesses = 0
- self.total_time = 0
- def assign_blocks(self, block_ref_dict):
- self.blocks = {k: ray.get(ref) for k, ref in block_ref_dict.items()}
- def get(self, block_index, key):
- start = time.perf_counter()
- result = self._get(block_index, key)
- self.total_time += time.perf_counter() - start
- self.num_accesses += 1
- return result
- def multiget(self, block_indices, keys):
- start = time.perf_counter()
- block = self.blocks[block_indices[0]]
- if len(set(block_indices)) == 1 and isinstance(
- self.blocks[block_indices[0]], pa.Table
- ):
- # Fast path: use np.searchsorted for vectorized search on a single block.
- # This is ~3x faster than the naive case.
- block = self.blocks[block_indices[0]]
- col = block[self.key_field]
- indices = np.searchsorted(col, keys)
- acc = BlockAccessor.for_block(block)
- result = [
- acc._get_row(i) if k1.as_py() == k2 else None
- for i, k1, k2 in zip(indices, col.take(indices), keys)
- ]
- else:
- result = [self._get(i, k) for i, k in zip(block_indices, keys)]
- self.total_time += time.perf_counter() - start
- self.num_accesses += 1
- return result
- def ping(self):
- return ray.get_runtime_context().get_node_id()
- def stats(self) -> dict:
- return {
- "num_blocks": len(self.blocks),
- "num_accesses": self.num_accesses,
- "total_time": self.total_time,
- }
- def _get(self, block_index, key):
- if block_index is None:
- return None
- block = self.blocks[block_index]
- column = block[self.key_field]
- if isinstance(block, pa.Table):
- column = _ArrowListWrapper(column)
- i = _binary_search_find(column, key)
- if i is None:
- return None
- acc = BlockAccessor.for_block(block)
- return acc._get_row(i)
- def _binary_search_find(column, x):
- i = bisect.bisect_left(column, x)
- if i != len(column) and column[i] == x:
- return i
- return None
- class _ArrowListWrapper:
- def __init__(self, arrow_col):
- self.arrow_col = arrow_col
- def __getitem__(self, i):
- return self.arrow_col[i].as_py()
- def __len__(self):
- return len(self.arrow_col)
- def _get_bounds(block, key):
- if len(block) == 0:
- return None
- b = (block[key][0], block[key][len(block) - 1])
- if isinstance(block, pa.Table):
- b = (b[0].as_py(), b[1].as_py())
- return b
|