datasource.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. import copy
  2. from typing import TYPE_CHECKING, Callable, Dict, Generator, Iterable, List, Optional
  3. import numpy as np
  4. import pyarrow as pa
  5. from ray.data._internal.util import _check_pyarrow_version
  6. from ray.data.block import Block, BlockMetadata, Schema
  7. from ray.data.datasource.util import _iter_sliced_blocks
  8. from ray.data.expressions import Expr
  9. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  10. if TYPE_CHECKING:
  11. from ray.data.context import DataContext
  12. class _DatasourceProjectionPushdownMixin:
  13. """Mixin for reading operators supporting projection pushdown"""
  14. def supports_projection_pushdown(self) -> bool:
  15. """Returns ``True`` in case ``Datasource`` supports projection operation
  16. being pushed down into the reading layer"""
  17. return False
  18. def get_projection_map(self) -> Optional[Dict[str, str]]:
  19. """Return the projection map (original column names -> final column names).
  20. Returns:
  21. Dict mapping original column names (in storage) to final column names
  22. (after optional renames). Keys indicate which columns are selected.
  23. None means all columns are selected with no renames.
  24. Empty dict {} means no columns are selected.
  25. """
  26. return self._projection_map
  27. def get_column_renames(self) -> Optional[Dict[str, str]]:
  28. """Return the column renames from the projection map.
  29. This is used by predicate pushdown to rewrite filter expressions
  30. from renamed column names back to original column names.
  31. Returns:
  32. Dict mapping original column names to renamed names,
  33. or None if no renaming has been applied.
  34. """
  35. if self._projection_map is None:
  36. return None
  37. # Only include actual renames (where key != value)
  38. renames = {k: v for k, v in self._projection_map.items() if k != v}
  39. return renames if renames else None
  40. def _get_data_columns(self) -> Optional[List[str]]:
  41. """Extract data columns from projection map.
  42. Helper method for datasources that need to pass columns to legacy
  43. read functions expecting separate columns and rename_map parameters.
  44. Returns:
  45. List of column names, or None if all columns should be read.
  46. Empty list [] means no columns.
  47. """
  48. return (
  49. list(self._projection_map.keys())
  50. if self._projection_map is not None
  51. else None
  52. )
  53. @staticmethod
  54. def _combine_projection_map(
  55. prev_projection_map: Optional[Dict[str, str]],
  56. new_projection_map: Optional[Dict[str, str]],
  57. ) -> Optional[Dict[str, str]]:
  58. """Combine two projection maps via transitive composition.
  59. Args:
  60. prev_projection_map: Previous projection (original -> intermediate names)
  61. new_projection_map: New projection to apply (intermediate -> final names)
  62. Returns:
  63. Combined projection map (original -> final names)
  64. Examples:
  65. >>> # Select columns a, b with no renames
  66. >>> prev = {"a": "a", "b": "b"}
  67. >>> # Select only 'a', rename to 'x'
  68. >>> new = {"a": "x"}
  69. >>> _DatasourceProjectionPushdownMixin._combine_projection_map(prev, new)
  70. {'a': 'x'}
  71. >>> # First rename a->temp
  72. >>> prev = {"a": "temp"}
  73. >>> # Then rename temp->final
  74. >>> new = {"temp": "final"}
  75. >>> _DatasourceProjectionPushdownMixin._combine_projection_map(prev, new)
  76. {'a': 'final'}
  77. """
  78. # Handle None cases (None means "all columns, no renames")
  79. if prev_projection_map is None:
  80. return new_projection_map
  81. elif new_projection_map is None:
  82. return prev_projection_map
  83. # Compose projections: for each original->intermediate mapping in prev,
  84. # check if intermediate is selected by new projection
  85. composed = {}
  86. for orig_col, intermediate_name in prev_projection_map.items():
  87. # If intermediate name is in new projection, follow the chain
  88. if intermediate_name in new_projection_map:
  89. final_name = new_projection_map[intermediate_name]
  90. composed[orig_col] = final_name
  91. # The composition already handles transitive chains correctly:
  92. # prev {a: temp}, new {temp: final} -> composed {a: final}
  93. # No need for collapse_transitive_map which would incorrectly remove
  94. # identity mappings like {b: b}
  95. return composed
  96. def apply_projection(
  97. self,
  98. projection_map: Optional[Dict[str, str]],
  99. ) -> "Datasource":
  100. """Apply a projection to this datasource.
  101. Args:
  102. projection_map: Dict mapping original column names (in storage)
  103. to final column names (after optional renames). Keys indicate
  104. which columns to select. None means select all columns with no renames.
  105. Returns:
  106. A new datasource instance with the projection applied.
  107. """
  108. clone = copy.copy(self)
  109. # Combine projections via transitive map composition
  110. clone._projection_map = self._combine_projection_map(
  111. self._projection_map, projection_map
  112. )
  113. return clone
  114. @staticmethod
  115. def _apply_rename(
  116. table: "pa.Table",
  117. column_rename_map: Optional[Dict[str, str]],
  118. ) -> "pa.Table":
  119. """Apply column renaming to a PyArrow table.
  120. Args:
  121. table: PyArrow table to rename
  122. column_rename_map: Mapping from old column names to new names
  123. Returns:
  124. Table with renamed columns
  125. """
  126. if not column_rename_map:
  127. return table
  128. new_names = [column_rename_map.get(col, col) for col in table.schema.names]
  129. return table.rename_columns(new_names)
  130. @staticmethod
  131. def _apply_rename_to_tables(
  132. tables: Iterable["pa.Table"],
  133. column_rename_map: Optional[Dict[str, str]],
  134. ) -> Generator["pa.Table", None, None]:
  135. """Wrap a table generator to apply column renaming to each table.
  136. This helper eliminates duplication across datasources that need to apply
  137. column renames to tables yielded from generators.
  138. Args:
  139. tables: Iterator/generator yielding PyArrow tables
  140. column_rename_map: Mapping from old column names to new names
  141. Yields:
  142. pa.Table: Tables with renamed columns
  143. """
  144. for table in tables:
  145. yield _DatasourceProjectionPushdownMixin._apply_rename(
  146. table, column_rename_map
  147. )
  148. class _DatasourcePredicatePushdownMixin:
  149. """Mixin for reading operators supporting predicate pushdown"""
  150. def __init__(self):
  151. self._predicate_expr: Optional[Expr] = None
  152. def supports_predicate_pushdown(self) -> bool:
  153. return False
  154. def get_current_predicate(self) -> Optional[Expr]:
  155. return self._predicate_expr
  156. def apply_predicate(
  157. self,
  158. predicate_expr: Expr,
  159. ) -> "Datasource":
  160. """Apply a predicate to this datasource.
  161. Default implementation that combines predicates using AND.
  162. Subclasses that support predicate pushdown should have a _predicate_expr
  163. attribute to store the predicate.
  164. Note: Column rebinding is handled by the PredicatePushdown rule
  165. before this method is called, so the predicate_expr should already
  166. reference the correct column names.
  167. """
  168. import copy
  169. clone = copy.copy(self)
  170. # Combine with existing predicate using AND
  171. clone._predicate_expr = (
  172. predicate_expr
  173. if clone._predicate_expr is None
  174. else clone._predicate_expr & predicate_expr
  175. )
  176. return clone
  177. @PublicAPI
  178. class Datasource(_DatasourceProjectionPushdownMixin, _DatasourcePredicatePushdownMixin):
  179. """Interface for defining a custom :class:`~ray.data.Dataset` datasource.
  180. To read a datasource into a dataset, use :meth:`~ray.data.read_datasource`.
  181. """ # noqa: E501
  182. def __init__(self):
  183. """Initialize the datasource and its mixins."""
  184. _DatasourcePredicatePushdownMixin.__init__(self)
  185. @Deprecated
  186. def create_reader(self, **read_args) -> "Reader":
  187. """
  188. Deprecated: Implement :meth:`~ray.data.Datasource.get_read_tasks` and
  189. :meth:`~ray.data.Datasource.estimate_inmemory_data_size` instead.
  190. """
  191. return _LegacyDatasourceReader(self, **read_args)
  192. @Deprecated
  193. def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask"]:
  194. """
  195. Deprecated: Implement :meth:`~ray.data.Datasource.get_read_tasks` and
  196. :meth:`~ray.data.Datasource.estimate_inmemory_data_size` instead.
  197. """
  198. raise NotImplementedError
  199. def get_name(self) -> str:
  200. """Return a human-readable name for this datasource.
  201. This will be used as the names of the read tasks.
  202. """
  203. name = type(self).__name__
  204. datasource_suffix = "Datasource"
  205. if name.endswith(datasource_suffix):
  206. name = name[: -len(datasource_suffix)]
  207. return name
  208. def estimate_inmemory_data_size(self) -> Optional[int]:
  209. """Return an estimate of the in-memory data size, or None if unknown.
  210. Note that the in-memory data size may be larger than the on-disk data size.
  211. """
  212. raise NotImplementedError
  213. def get_read_tasks(
  214. self,
  215. parallelism: int,
  216. per_task_row_limit: Optional[int] = None,
  217. data_context: Optional["DataContext"] = None,
  218. ) -> List["ReadTask"]:
  219. """Execute the read and return read tasks.
  220. Args:
  221. parallelism: The requested read parallelism. The number of read
  222. tasks should equal to this value if possible.
  223. per_task_row_limit: The per-task row limit for the read tasks.
  224. data_context: The data context to use to get read tasks.
  225. Returns:
  226. A list of read tasks that can be executed to read blocks from the
  227. datasource in parallel.
  228. """
  229. raise NotImplementedError
  230. @property
  231. def should_create_reader(self) -> bool:
  232. has_implemented_get_read_tasks = (
  233. type(self).get_read_tasks is not Datasource.get_read_tasks
  234. )
  235. has_implemented_estimate_inmemory_data_size = (
  236. type(self).estimate_inmemory_data_size
  237. is not Datasource.estimate_inmemory_data_size
  238. )
  239. return (
  240. not has_implemented_get_read_tasks
  241. or not has_implemented_estimate_inmemory_data_size
  242. )
  243. @property
  244. def supports_distributed_reads(self) -> bool:
  245. """If ``False``, only launch read tasks on the driver's node."""
  246. return True
  247. @Deprecated
  248. class Reader:
  249. """A bound read operation for a :class:`~ray.data.Datasource`.
  250. This is a stateful class so that reads can be prepared in multiple stages.
  251. For example, it is useful for :class:`Datasets <ray.data.Dataset>` to know the
  252. in-memory size of the read prior to executing it.
  253. """
  254. def estimate_inmemory_data_size(self) -> Optional[int]:
  255. """Return an estimate of the in-memory data size, or None if unknown.
  256. Note that the in-memory data size may be larger than the on-disk data size.
  257. """
  258. raise NotImplementedError
  259. def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
  260. """Execute the read and return read tasks.
  261. Args:
  262. parallelism: The requested read parallelism. The number of read
  263. tasks should equal to this value if possible.
  264. Returns:
  265. A list of read tasks that can be executed to read blocks from the
  266. datasource in parallel.
  267. """
  268. raise NotImplementedError
  269. class _LegacyDatasourceReader(Reader):
  270. def __init__(self, datasource: Datasource, **read_args):
  271. self._datasource = datasource
  272. self._read_args = read_args
  273. def estimate_inmemory_data_size(self) -> Optional[int]:
  274. return None
  275. def get_read_tasks(
  276. self,
  277. parallelism: int,
  278. per_task_row_limit: Optional[int] = None,
  279. data_context: Optional["DataContext"] = None,
  280. ) -> List["ReadTask"]:
  281. """Execute the read and return read tasks.
  282. Args:
  283. parallelism: The requested read parallelism. The number of read
  284. tasks should equal to this value if possible.
  285. per_task_row_limit: The per-task row limit for the read tasks.
  286. data_context: The data context to use to get read tasks. Not used by this
  287. legacy reader.
  288. Returns:
  289. A list of read tasks that can be executed to read blocks from the
  290. datasource in parallel.
  291. """
  292. return self._datasource.prepare_read(parallelism, **self._read_args)
  293. @DeveloperAPI
  294. class ReadTask(Callable[[], Iterable[Block]]):
  295. """A function used to read blocks from the :class:`~ray.data.Dataset`.
  296. Read tasks are generated by :meth:`~ray.data.Datasource.get_read_tasks`,
  297. and return a list of ``ray.data.Block`` when called. Initial metadata about the read
  298. operation can be retrieved via the ``metadata`` attribute prior to executing the
  299. read. Final metadata is returned after the read along with the blocks.
  300. Ray will execute read tasks in remote functions to parallelize execution.
  301. Note that the number of blocks returned can vary at runtime. For example,
  302. if a task is reading a single large file it can return multiple blocks to
  303. avoid running out of memory during the read.
  304. The initial metadata should reflect all the blocks returned by the read,
  305. e.g., if the metadata says ``num_rows=1000``, the read can return a single
  306. block of 1000 rows, or multiple blocks with 1000 rows altogether.
  307. The final metadata (returned with the actual block) reflects the exact
  308. contents of the block itself.
  309. """
  310. def __init__(
  311. self,
  312. read_fn: Callable[[], Iterable[Block]],
  313. metadata: BlockMetadata,
  314. schema: Optional["Schema"] = None,
  315. per_task_row_limit: Optional[int] = None,
  316. ):
  317. self._metadata = metadata
  318. self._read_fn = read_fn
  319. self._schema = schema
  320. self._per_task_row_limit = per_task_row_limit
  321. @property
  322. def metadata(self) -> BlockMetadata:
  323. return self._metadata
  324. # TODO(justin): We want to remove schema from `ReadTask` later on
  325. @property
  326. def schema(self) -> Optional["Schema"]:
  327. return self._schema
  328. @property
  329. def read_fn(self) -> Callable[[], Iterable[Block]]:
  330. return self._read_fn
  331. @property
  332. def per_task_row_limit(self) -> Optional[int]:
  333. """Get the per-task row limit for this read task."""
  334. return self._per_task_row_limit
  335. def __call__(self) -> Iterable[Block]:
  336. result = self._read_fn()
  337. if not hasattr(result, "__iter__"):
  338. DeprecationWarning(
  339. "Read function must return Iterable[Block], got {}. "
  340. "Probably you need to return `[block]` instead of "
  341. "`block`.".format(result)
  342. )
  343. if self._per_task_row_limit is None:
  344. yield from result
  345. return
  346. yield from _iter_sliced_blocks(result, self._per_task_row_limit)
  347. @DeveloperAPI
  348. class RandomIntRowDatasource(Datasource):
  349. """An example datasource that generates rows with random int64 columns.
  350. Examples:
  351. >>> import ray
  352. >>> from ray.data.datasource import RandomIntRowDatasource
  353. >>> source = RandomIntRowDatasource() # doctest: +SKIP
  354. >>> ray.data.read_datasource( # doctest: +SKIP
  355. ... source, n=10, num_columns=2).take()
  356. {'c_0': 1717767200176864416, 'c_1': 999657309586757214}
  357. {'c_0': 4983608804013926748, 'c_1': 1160140066899844087}
  358. """
  359. def __init__(self, n: int, num_columns: int):
  360. """Initialize the datasource that generates random-integer rows.
  361. Args:
  362. n: The number of rows to generate.
  363. num_columns: The number of columns to generate.
  364. """
  365. self._n = n
  366. self._num_columns = num_columns
  367. def estimate_inmemory_data_size(self) -> Optional[int]:
  368. return self._n * self._num_columns * 8
  369. def get_read_tasks(
  370. self,
  371. parallelism: int,
  372. per_task_row_limit: Optional[int] = None,
  373. data_context: Optional["DataContext"] = None,
  374. ) -> List[ReadTask]:
  375. _check_pyarrow_version()
  376. import pyarrow
  377. read_tasks: List[ReadTask] = []
  378. n = self._n
  379. num_columns = self._num_columns
  380. block_size = max(1, n // parallelism)
  381. def make_block(count: int, num_columns: int) -> Block:
  382. return pyarrow.Table.from_arrays(
  383. np.random.randint(
  384. np.iinfo(np.int64).max, size=(num_columns, count), dtype=np.int64
  385. ),
  386. names=[f"c_{i}" for i in range(num_columns)],
  387. )
  388. schema = pyarrow.Table.from_pydict(
  389. {f"c_{i}": [0] for i in range(num_columns)}
  390. ).schema
  391. i = 0
  392. while i < n:
  393. count = min(block_size, n - i)
  394. meta = BlockMetadata(
  395. num_rows=count,
  396. size_bytes=8 * count * num_columns,
  397. input_files=None,
  398. exec_stats=None,
  399. )
  400. read_tasks.append(
  401. ReadTask(
  402. lambda count=count, num_columns=num_columns: [
  403. make_block(count, num_columns)
  404. ],
  405. meta,
  406. schema=schema,
  407. per_task_row_limit=per_task_row_limit,
  408. )
  409. )
  410. i += block_size
  411. return read_tasks
  412. def get_name(self) -> str:
  413. """Return a human-readable name for this datasource.
  414. This will be used as the names of the read tasks.
  415. Note: overrides the base `Datasource` method.
  416. """
  417. return "RandomInt"