random_access_dataset.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import bisect
  2. import logging
  3. import random
  4. import time
  5. from collections import defaultdict
  6. from typing import TYPE_CHECKING, Any, List, Optional
  7. import numpy as np
  8. import ray
  9. from ray.data._internal.execution.interfaces.ref_bundle import (
  10. _ref_bundles_iterator_to_block_refs_list,
  11. )
  12. from ray.data._internal.remote_fn import cached_remote_fn
  13. from ray.data.block import BlockAccessor
  14. from ray.data.context import DataContext
  15. from ray.types import ObjectRef
  16. from ray.util.annotations import PublicAPI
  17. try:
  18. import pyarrow as pa
  19. except ImportError:
  20. pa = None
  21. if TYPE_CHECKING:
  22. from ray.data.dataset import Dataset
  23. logger = logging.getLogger(__name__)
  24. @PublicAPI(stability="alpha")
  25. class RandomAccessDataset:
  26. """A class that provides distributed, random access to a Dataset.
  27. See: ``Dataset.to_random_access_dataset()``.
  28. """
  29. def __init__(
  30. self,
  31. ds: "Dataset",
  32. key: str,
  33. num_workers: int,
  34. ):
  35. """Construct a RandomAccessDataset (internal API).
  36. The constructor is a private API. Use ``ds.to_random_access_dataset()``
  37. to construct a RandomAccessDataset.
  38. """
  39. schema = ds.schema(fetch_if_missing=True)
  40. if schema is None or isinstance(schema, type):
  41. raise ValueError("RandomAccessDataset only supports Arrow-format blocks.")
  42. start = time.perf_counter()
  43. logger.info("[setup] Indexing dataset by sort key.")
  44. sorted_ds = ds.sort(key)
  45. get_bounds = cached_remote_fn(_get_bounds)
  46. bundles = sorted_ds.iter_internal_ref_bundles()
  47. blocks = _ref_bundles_iterator_to_block_refs_list(bundles)
  48. logger.info("[setup] Computing block range bounds.")
  49. bounds = ray.get([get_bounds.remote(b, key) for b in blocks])
  50. self._non_empty_blocks = []
  51. self._lower_bound = None
  52. self._upper_bounds = []
  53. for i, b in enumerate(bounds):
  54. if b:
  55. self._non_empty_blocks.append(blocks[i])
  56. if self._lower_bound is None:
  57. self._lower_bound = b[0]
  58. self._upper_bounds.append(b[1])
  59. logger.info("[setup] Creating {} random access workers.".format(num_workers))
  60. ctx = DataContext.get_current()
  61. scheduling_strategy = ctx.scheduling_strategy
  62. self._workers = [
  63. _RandomAccessWorker.options(scheduling_strategy=scheduling_strategy).remote(
  64. key
  65. )
  66. for _ in range(num_workers)
  67. ]
  68. (
  69. self._block_to_workers_map,
  70. self._worker_to_blocks_map,
  71. ) = self._compute_block_to_worker_assignments()
  72. logger.info(
  73. "[setup] Worker to blocks assignment: {}".format(self._worker_to_blocks_map)
  74. )
  75. ray.get(
  76. [
  77. w.assign_blocks.remote(
  78. {
  79. i: self._non_empty_blocks[i]
  80. for i in self._worker_to_blocks_map[w]
  81. }
  82. )
  83. for w in self._workers
  84. ]
  85. )
  86. logger.info("[setup] Finished assigning blocks to workers.")
  87. self._build_time = time.perf_counter() - start
  88. def _compute_block_to_worker_assignments(self):
  89. # Return values.
  90. block_to_workers: dict[int, List["ray.ActorHandle"]] = defaultdict(list)
  91. worker_to_blocks: dict["ray.ActorHandle", List[int]] = defaultdict(list)
  92. # Aux data structures.
  93. loc_to_workers: dict[str, List["ray.ActorHandle"]] = defaultdict(list)
  94. locs = ray.get([w.ping.remote() for w in self._workers])
  95. for i, loc in enumerate(locs):
  96. loc_to_workers[loc].append(self._workers[i])
  97. block_locs = ray.experimental.get_object_locations(self._non_empty_blocks)
  98. # First, try to assign all blocks to all workers at its location.
  99. for block_idx, block in enumerate(self._non_empty_blocks):
  100. block_info = block_locs[block]
  101. locs = block_info.get("node_ids", [])
  102. for loc in locs:
  103. for worker in loc_to_workers[loc]:
  104. block_to_workers[block_idx].append(worker)
  105. worker_to_blocks[worker].append(block_idx)
  106. # Randomly assign any leftover blocks to at least one worker.
  107. # TODO: the load balancing here could be improved.
  108. for block_idx, block in enumerate(self._non_empty_blocks):
  109. if len(block_to_workers[block_idx]) == 0:
  110. worker = random.choice(self._workers)
  111. block_to_workers[block_idx].append(worker)
  112. worker_to_blocks[worker].append(block_idx)
  113. return block_to_workers, worker_to_blocks
  114. def get_async(self, key: Any) -> ObjectRef[Any]:
  115. """Asynchronously finds the record for a single key.
  116. Args:
  117. key: The key of the record to find.
  118. Returns:
  119. ObjectRef containing the record (in pydict form), or None if not found.
  120. """
  121. block_index = self._find_le(key)
  122. if block_index is None:
  123. return ray.put(None)
  124. return self._worker_for(block_index).get.remote(block_index, key)
  125. def multiget(self, keys: List[Any]) -> List[Optional[Any]]:
  126. """Synchronously find the records for a list of keys.
  127. Args:
  128. keys: List of keys to find the records for.
  129. Returns:
  130. List of found records (in pydict form), or None for missing records.
  131. """
  132. batches = defaultdict(list)
  133. for k in keys:
  134. batches[self._find_le(k)].append(k)
  135. futures = {}
  136. for index, keybatch in batches.items():
  137. if index is None:
  138. continue
  139. fut = self._worker_for(index).multiget.remote(
  140. [index] * len(keybatch), keybatch
  141. )
  142. futures[index] = fut
  143. results = {}
  144. for i, fut in futures.items():
  145. keybatch = batches[i]
  146. values = ray.get(fut)
  147. for k, v in zip(keybatch, values):
  148. results[k] = v
  149. return [results.get(k) for k in keys]
  150. def stats(self) -> str:
  151. """Returns a string containing access timing information."""
  152. stats = ray.get([w.stats.remote() for w in self._workers])
  153. total_time = sum(s["total_time"] for s in stats)
  154. accesses = [s["num_accesses"] for s in stats]
  155. blocks = [s["num_blocks"] for s in stats]
  156. msg = "RandomAccessDataset:\n"
  157. msg += "- Build time: {}s\n".format(round(self._build_time, 2))
  158. msg += "- Num workers: {}\n".format(len(stats))
  159. msg += "- Blocks per worker: {} min, {} max, {} mean\n".format(
  160. min(blocks), max(blocks), int(sum(blocks) / len(blocks))
  161. )
  162. msg += "- Accesses per worker: {} min, {} max, {} mean\n".format(
  163. min(accesses), max(accesses), int(sum(accesses) / len(accesses))
  164. )
  165. msg += "- Mean access time: {}us\n".format(
  166. int(total_time / (1 + sum(accesses)) * 1e6)
  167. )
  168. return msg
  169. def _worker_for(self, block_index: int):
  170. return random.choice(self._block_to_workers_map[block_index])
  171. def _find_le(self, x: Any) -> int:
  172. i = bisect.bisect_left(self._upper_bounds, x)
  173. if i >= len(self._upper_bounds) or x < self._lower_bound:
  174. return None
  175. return i
  176. @ray.remote(num_cpus=0)
  177. class _RandomAccessWorker:
  178. def __init__(self, key_field):
  179. self.blocks = None
  180. self.key_field = key_field
  181. self.num_accesses = 0
  182. self.total_time = 0
  183. def assign_blocks(self, block_ref_dict):
  184. self.blocks = {k: ray.get(ref) for k, ref in block_ref_dict.items()}
  185. def get(self, block_index, key):
  186. start = time.perf_counter()
  187. result = self._get(block_index, key)
  188. self.total_time += time.perf_counter() - start
  189. self.num_accesses += 1
  190. return result
  191. def multiget(self, block_indices, keys):
  192. start = time.perf_counter()
  193. block = self.blocks[block_indices[0]]
  194. if len(set(block_indices)) == 1 and isinstance(
  195. self.blocks[block_indices[0]], pa.Table
  196. ):
  197. # Fast path: use np.searchsorted for vectorized search on a single block.
  198. # This is ~3x faster than the naive case.
  199. block = self.blocks[block_indices[0]]
  200. col = block[self.key_field]
  201. indices = np.searchsorted(col, keys)
  202. acc = BlockAccessor.for_block(block)
  203. result = [
  204. acc._get_row(i) if k1.as_py() == k2 else None
  205. for i, k1, k2 in zip(indices, col.take(indices), keys)
  206. ]
  207. else:
  208. result = [self._get(i, k) for i, k in zip(block_indices, keys)]
  209. self.total_time += time.perf_counter() - start
  210. self.num_accesses += 1
  211. return result
  212. def ping(self):
  213. return ray.get_runtime_context().get_node_id()
  214. def stats(self) -> dict:
  215. return {
  216. "num_blocks": len(self.blocks),
  217. "num_accesses": self.num_accesses,
  218. "total_time": self.total_time,
  219. }
  220. def _get(self, block_index, key):
  221. if block_index is None:
  222. return None
  223. block = self.blocks[block_index]
  224. column = block[self.key_field]
  225. if isinstance(block, pa.Table):
  226. column = _ArrowListWrapper(column)
  227. i = _binary_search_find(column, key)
  228. if i is None:
  229. return None
  230. acc = BlockAccessor.for_block(block)
  231. return acc._get_row(i)
  232. def _binary_search_find(column, x):
  233. i = bisect.bisect_left(column, x)
  234. if i != len(column) and column[i] == x:
  235. return i
  236. return None
  237. class _ArrowListWrapper:
  238. def __init__(self, arrow_col):
  239. self.arrow_col = arrow_col
  240. def __getitem__(self, i):
  241. return self.arrow_col[i].as_py()
  242. def __len__(self):
  243. return len(self.arrow_col)
  244. def _get_bounds(block, key):
  245. if len(block) == 0:
  246. return None
  247. b = (block[key][0], block[key][len(block) - 1])
  248. if isinstance(block, pa.Table):
  249. b = (b[0].as_py(), b[1].as_py())
  250. return b