file_based_datasource.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. import io
  2. import logging
  3. from dataclasses import dataclass
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. Callable,
  8. Dict,
  9. Iterable,
  10. Iterator,
  11. List,
  12. Literal,
  13. Optional,
  14. Tuple,
  15. TypeVar,
  16. Union,
  17. )
  18. import numpy as np
  19. import ray
  20. from ray.data._internal.util import (
  21. RetryingContextManager,
  22. RetryingPyFileSystem,
  23. _check_pyarrow_version,
  24. _is_local_scheme,
  25. infer_compression,
  26. iterate_with_retry,
  27. make_async_gen,
  28. )
  29. from ray.data.block import Block, BlockAccessor
  30. from ray.data.context import DataContext
  31. from ray.data.datasource.datasource import Datasource, ReadTask
  32. from ray.data.datasource.file_meta_provider import (
  33. BaseFileMetadataProvider,
  34. DefaultFileMetadataProvider,
  35. )
  36. from ray.data.datasource.partitioning import (
  37. Partitioning,
  38. PathPartitionFilter,
  39. PathPartitionParser,
  40. )
  41. from ray.data.datasource.path_util import (
  42. _has_file_extension,
  43. _resolve_paths_and_filesystem,
  44. )
  45. from ray.util.annotations import DeveloperAPI
  46. if TYPE_CHECKING:
  47. import pandas as pd
  48. import pyarrow
# Module-level logger; in this file it is used for the debug messages emitted
# while read tasks iterate over their files.
logger = logging.getLogger(__name__)

# We should parallelize file size fetch operations beyond this threshold.
# NOTE(review): not referenced in this part of the file; presumably consumed by
# the file metadata providers -- confirm before changing.
FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD = 16

# 16 file size fetches from S3 take ~1.5 seconds with Arrow's S3FileSystem.
PATHS_PER_FILE_SIZE_FETCH_TASK = 16
  54. @DeveloperAPI
  55. @dataclass
  56. class FileShuffleConfig:
  57. """Configuration for file shuffling.
  58. This configuration object controls how files are shuffled while reading file-based
  59. datasets. The random seed behavior is determined by the combination of ``seed``
  60. and ``reseed_after_execution``:
  61. - If ``seed`` is None, the random seed is always None (non-deterministic shuffling).
  62. - If ``seed`` is not None and ``reseed_after_execution`` is False, the random seed is
  63. constantly ``seed`` across executions.
  64. - If ``seed`` is not None and ``reseed_after_execution`` is True, the random seed is
  65. different for each execution.
  66. .. note::
  67. Even if you provided a seed, you might still observe a non-deterministic row
  68. order. This is because tasks are executed in parallel and their completion
  69. order might vary. If you need to preserve the order of rows, set
  70. ``DataContext.get_current().execution_options.preserve_order``.
  71. Args:
  72. seed: An optional integer seed for the file shuffler. If None, shuffling is
  73. non-deterministic. If provided, shuffling is deterministic based on this
  74. seed and the ``reseed_after_execution`` setting.
  75. reseed_after_execution: If True, the random seed considers both ``seed`` and
  76. ``execution_idx``, resulting in different shuffling orders across executions.
  77. If False, the random seed is constantly ``seed``, resulting in the same
  78. shuffling order across executions. Only takes effect when ``seed`` is not None.
  79. Defaults to True.
  80. Example:
  81. >>> import ray
  82. >>> from ray.data import FileShuffleConfig
  83. >>> # Fixed seed - same shuffle across executions
  84. >>> shuffle = FileShuffleConfig(seed=42, reseed_after_execution=False)
  85. >>> ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea", shuffle=shuffle)
  86. >>>
  87. >>> # Seed with reseed_after_execution - different shuffle per execution
  88. >>> shuffle = FileShuffleConfig(seed=42, reseed_after_execution=True)
  89. >>> ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea", shuffle=shuffle)
  90. """ # noqa: E501
  91. seed: Optional[int] = None
  92. reseed_after_execution: bool = True
  93. def __post_init__(self):
  94. """Ensure that the seed is either None or an integer."""
  95. if self.seed is not None and not isinstance(self.seed, int):
  96. raise ValueError("Seed must be an integer or None.")
  97. def get_seed(self, execution_idx: int = 0) -> Optional[int]:
  98. if self.seed is None:
  99. return None
  100. elif self.reseed_after_execution:
  101. # Modulo ensures the result is in valid NumPy seed range [0, 2**32 - 1].
  102. return hash((self.seed, execution_idx)) % (2**32)
  103. else:
  104. return self.seed
  105. @DeveloperAPI
  106. class FileBasedDatasource(Datasource):
  107. """File-based datasource for reading files.
  108. Don't use this class directly. Instead, subclass it and implement `_read_stream()`.
  109. """
  110. # If `_WRITE_FILE_PER_ROW` is `True`, this datasource calls `_write_row` and writes
  111. # each row to a file. Otherwise, this datasource calls `_write_block` and writes
  112. # each block to a file.
  113. _WRITE_FILE_PER_ROW = False
  114. _FILE_EXTENSIONS: Optional[Union[str, List[str]]] = None
  115. # Number of threads for concurrent reading within each read task.
  116. # If zero or negative, reading will be performed in the main thread.
  117. _NUM_THREADS_PER_TASK = 0
  118. def __init__(
  119. self,
  120. paths: Union[str, List[str]],
  121. *,
  122. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  123. schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
  124. open_stream_args: Optional[Dict[str, Any]] = None,
  125. meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
  126. partition_filter: PathPartitionFilter = None,
  127. partitioning: Partitioning = None,
  128. ignore_missing_paths: bool = False,
  129. shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
  130. include_paths: bool = False,
  131. file_extensions: Optional[List[str]] = None,
  132. ):
  133. super().__init__()
  134. _check_pyarrow_version()
  135. self._supports_distributed_reads = not _is_local_scheme(paths)
  136. if not self._supports_distributed_reads and ray.util.client.ray.is_connected():
  137. raise ValueError(
  138. "Because you're using Ray Client, read tasks scheduled on the Ray "
  139. "cluster can't access your local files. To fix this issue, store "
  140. "files in cloud storage or a distributed filesystem like NFS."
  141. )
  142. self._schema = schema
  143. self._data_context = DataContext.get_current()
  144. self._open_stream_args = open_stream_args
  145. self._meta_provider = meta_provider
  146. self._partition_filter = partition_filter
  147. self._partitioning = partitioning
  148. self._ignore_missing_paths = ignore_missing_paths
  149. self._include_paths = include_paths
  150. # Need this property for lineage tracking. We should not directly assign paths
  151. # to self since it is captured every read_task_fn during serialization and
  152. # causing this data being duplicated and excessive object store spilling.
  153. self._source_paths_ref = ray.put(paths)
  154. paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
  155. self._filesystem = RetryingPyFileSystem.wrap(
  156. self._filesystem, retryable_errors=self._data_context.retried_io_errors
  157. )
  158. paths, file_sizes = map(
  159. list,
  160. zip(
  161. *meta_provider.expand_paths(
  162. paths,
  163. self._filesystem,
  164. partitioning,
  165. ignore_missing_paths=ignore_missing_paths,
  166. )
  167. ),
  168. )
  169. if ignore_missing_paths and len(paths) == 0:
  170. raise ValueError(
  171. "None of the provided paths exist. "
  172. "The 'ignore_missing_paths' field is set to True."
  173. )
  174. if self._partition_filter is not None:
  175. # Use partition filter to skip files which are not needed.
  176. path_to_size = dict(zip(paths, file_sizes))
  177. paths = self._partition_filter(paths)
  178. file_sizes = [path_to_size[p] for p in paths]
  179. if len(paths) == 0:
  180. raise ValueError(
  181. "No input files found to read. Please double check that "
  182. "'partition_filter' field is set properly."
  183. )
  184. if file_extensions is not None:
  185. path_to_size = dict(zip(paths, file_sizes))
  186. paths = [p for p in paths if _has_file_extension(p, file_extensions)]
  187. file_sizes = [path_to_size[p] for p in paths]
  188. if len(paths) == 0:
  189. raise ValueError(
  190. "No input files found to read with the following file extensions: "
  191. f"{file_extensions}. Please double check that "
  192. "'file_extensions' field is set properly."
  193. )
  194. _validate_shuffle_arg(shuffle)
  195. self._shuffle = shuffle
  196. # Read tasks serialize `FileBasedDatasource` instances, and the list of paths
  197. # can be large. To avoid slow serialization speeds, we store a reference to
  198. # the paths rather than the paths themselves.
  199. self._paths_ref = ray.put(paths)
  200. self._file_sizes_ref = ray.put(file_sizes)
  201. @property
  202. def _source_paths(self) -> List[str]:
  203. return ray.get(self._source_paths_ref)
  204. def _paths(self) -> List[str]:
  205. return ray.get(self._paths_ref)
  206. def _file_sizes(self) -> List[float]:
  207. return ray.get(self._file_sizes_ref)
  208. def estimate_inmemory_data_size(self) -> Optional[int]:
  209. total_size = 0
  210. for sz in self._file_sizes():
  211. if sz is not None:
  212. total_size += sz
  213. return total_size
  214. def get_read_tasks(
  215. self,
  216. parallelism: int,
  217. per_task_row_limit: Optional[int] = None,
  218. data_context: Optional["DataContext"] = None,
  219. ) -> List[ReadTask]:
  220. import numpy as np
  221. open_stream_args = self._open_stream_args
  222. partitioning = self._partitioning
  223. paths = self._paths()
  224. file_sizes = self._file_sizes()
  225. execution_idx = data_context._execution_idx if data_context is not None else 0
  226. paths, file_sizes = _shuffle_file_metadata(
  227. paths, file_sizes, self._shuffle, execution_idx
  228. )
  229. filesystem = _wrap_s3_serialization_workaround(self._filesystem)
  230. if open_stream_args is None:
  231. open_stream_args = {}
  232. def read_files(
  233. read_paths: Iterable[str],
  234. ) -> Iterable[Block]:
  235. nonlocal filesystem, open_stream_args, partitioning
  236. fs = _unwrap_s3_serialization_workaround(filesystem)
  237. for read_path in read_paths:
  238. partitions: Dict[str, str] = {}
  239. if partitioning is not None:
  240. parse = PathPartitionParser(partitioning)
  241. partitions = parse(read_path)
  242. with RetryingContextManager(
  243. self._open_input_source(fs, read_path, **open_stream_args),
  244. context=self._data_context,
  245. ) as f:
  246. for block in iterate_with_retry(
  247. lambda: self._read_stream(f, read_path),
  248. description="read stream iteratively",
  249. match=self._data_context.retried_io_errors,
  250. ):
  251. if partitions:
  252. block = _add_partitions(block, partitions)
  253. if self._include_paths:
  254. block_accessor = BlockAccessor.for_block(block)
  255. block = block_accessor.fill_column("path", read_path)
  256. yield block
  257. def create_read_task_fn(read_paths, num_threads):
  258. def read_task_fn():
  259. nonlocal num_threads, read_paths
  260. # TODO: We should refactor the code so that we can get the results in
  261. # order even when using multiple threads.
  262. if self._data_context.execution_options.preserve_order:
  263. num_threads = 0
  264. if num_threads > 0:
  265. num_threads = min(num_threads, len(read_paths))
  266. logger.debug(
  267. f"Reading {len(read_paths)} files with {num_threads} threads."
  268. )
  269. yield from make_async_gen(
  270. iter(read_paths),
  271. read_files,
  272. num_workers=num_threads,
  273. preserve_ordering=True,
  274. )
  275. else:
  276. logger.debug(f"Reading {len(read_paths)} files.")
  277. yield from read_files(read_paths)
  278. return read_task_fn
  279. # fix https://github.com/ray-project/ray/issues/24296
  280. parallelism = min(parallelism, len(paths))
  281. read_tasks = []
  282. split_paths = np.array_split(paths, parallelism)
  283. split_file_sizes = np.array_split(file_sizes, parallelism)
  284. for read_paths, file_sizes in zip(split_paths, split_file_sizes):
  285. if len(read_paths) <= 0:
  286. continue
  287. meta = self._meta_provider(
  288. read_paths,
  289. rows_per_file=self._rows_per_file(),
  290. file_sizes=file_sizes,
  291. )
  292. read_task_fn = create_read_task_fn(read_paths, self._NUM_THREADS_PER_TASK)
  293. read_task = ReadTask(
  294. read_task_fn, meta, per_task_row_limit=per_task_row_limit
  295. )
  296. read_tasks.append(read_task)
  297. return read_tasks
  298. def resolve_compression(
  299. self, path: str, open_args: Dict[str, Any]
  300. ) -> Optional[str]:
  301. """Resolves the compression format for a stream.
  302. Args:
  303. path: The file path to resolve compression for.
  304. open_args: kwargs passed to
  305. `pyarrow.fs.FileSystem.open_input_stream <https://arrow.apache.org/docs/python/generated/pyarrow.fs.FileSystem.html#pyarrow.fs.FileSystem.open_input_stream>`_
  306. when opening input files to read.
  307. Returns:
  308. The compression format (e.g., "gzip", "snappy", "bz2") or None if
  309. no compression is detected or specified.
  310. """
  311. compression = open_args.get("compression", None)
  312. if compression is None:
  313. compression = infer_compression(path)
  314. return compression
  315. def _resolve_buffer_size(self, open_args: Dict[str, Any]) -> Optional[int]:
  316. buffer_size = open_args.pop("buffer_size", None)
  317. if buffer_size is None:
  318. buffer_size = self._data_context.streaming_read_buffer_size
  319. return buffer_size
  320. def _file_to_snappy_stream(
  321. self,
  322. file: "pyarrow.NativeFile",
  323. filesystem: "RetryingPyFileSystem",
  324. ) -> "pyarrow.PythonFile":
  325. import pyarrow as pa
  326. import snappy
  327. from pyarrow.fs import HadoopFileSystem
  328. stream = io.BytesIO()
  329. if isinstance(filesystem.unwrap(), HadoopFileSystem):
  330. snappy.hadoop_snappy.stream_decompress(src=file, dst=stream)
  331. else:
  332. snappy.stream_decompress(src=file, dst=stream)
  333. stream.seek(0)
  334. return pa.PythonFile(stream, mode="r")
  335. def _open_input_source(
  336. self,
  337. filesystem: "RetryingPyFileSystem",
  338. path: str,
  339. **open_args,
  340. ) -> "pyarrow.NativeFile":
  341. """Opens a source path for reading and returns the associated Arrow NativeFile.
  342. The default implementation opens the source path as a sequential input stream,
  343. using self._data_context.streaming_read_buffer_size as the buffer size if none
  344. is given by the caller.
  345. Implementations that do not support streaming reads (e.g. that require random
  346. access) should override this method.
  347. """
  348. compression = self.resolve_compression(path, open_args)
  349. buffer_size = self._resolve_buffer_size(open_args)
  350. if compression == "snappy":
  351. # Arrow doesn't support streaming Snappy decompression since the canonical
  352. # C++ Snappy library doesn't natively support streaming decompression. We
  353. # works around this by manually decompressing the file with python-snappy.
  354. open_args["compression"] = None
  355. file = filesystem.open_input_stream(
  356. path, buffer_size=buffer_size, **open_args
  357. )
  358. return self._file_to_snappy_stream(file, filesystem)
  359. open_args["compression"] = compression
  360. return filesystem.open_input_stream(path, buffer_size=buffer_size, **open_args)
  361. def _rows_per_file(self):
  362. """Returns the number of rows per file, or None if unknown."""
  363. return None
  364. def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
  365. """Streaming read a single file.
  366. This method should be implemented by subclasses.
  367. """
  368. raise NotImplementedError(
  369. "Subclasses of FileBasedDatasource must implement _read_stream()."
  370. )
  371. @property
  372. def supports_distributed_reads(self) -> bool:
  373. return self._supports_distributed_reads
  374. def _add_partitions(
  375. data: Union["pyarrow.Table", "pd.DataFrame"], partitions: Dict[str, Any]
  376. ) -> Union["pyarrow.Table", "pd.DataFrame"]:
  377. import pandas as pd
  378. import pyarrow as pa
  379. assert isinstance(data, (pa.Table, pd.DataFrame))
  380. if isinstance(data, pa.Table):
  381. return _add_partitions_to_table(data, partitions)
  382. if isinstance(data, pd.DataFrame):
  383. return _add_partitions_to_dataframe(data, partitions)
  384. def _add_partitions_to_table(
  385. table: "pyarrow.Table", partitions: Dict[str, Any]
  386. ) -> "pyarrow.Table":
  387. import pyarrow as pa
  388. import pyarrow.compute as pc
  389. column_names = set(table.column_names)
  390. for field, value in partitions.items():
  391. column = pa.array([value] * len(table))
  392. if field in column_names:
  393. # TODO: Handle cast error.
  394. column_type = table.schema.field(field).type
  395. column = column.cast(column_type)
  396. values_are_equal = pc.all(pc.equal(column, table[field]))
  397. values_are_equal = values_are_equal.as_py()
  398. if not values_are_equal:
  399. raise ValueError(
  400. f"Partition column {field} exists in table data, but partition "
  401. f"value '{value}' is different from in-data values: "
  402. f"{table[field].unique().to_pylist()}."
  403. )
  404. i = table.schema.get_field_index(field)
  405. table = table.set_column(i, field, column)
  406. else:
  407. table = table.append_column(field, column)
  408. return table
  409. def _add_partitions_to_dataframe(
  410. df: "pd.DataFrame", partitions: Dict[str, Any]
  411. ) -> "pd.DataFrame":
  412. import pandas as pd
  413. for field, value in partitions.items():
  414. column = pd.Series(data=[value] * len(df), name=field)
  415. if field in df:
  416. column = column.astype(df[field].dtype)
  417. mask = df[field].notna()
  418. if not df[field][mask].equals(column[mask]):
  419. raise ValueError(
  420. f"Partition column {field} exists in table data, but partition "
  421. f"value '{value}' is different from in-data values: "
  422. f"{list(df[field].unique())}."
  423. )
  424. df[field] = column
  425. return df
  426. def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"):
  427. # This is needed because pa.fs.S3FileSystem assumes pa.fs is already
  428. # imported before deserialization. See #17085.
  429. import pyarrow as pa
  430. import pyarrow.fs
  431. base_fs = filesystem
  432. if isinstance(filesystem, RetryingPyFileSystem):
  433. base_fs = filesystem.unwrap()
  434. if isinstance(base_fs, pa.fs.S3FileSystem):
  435. return _S3FileSystemWrapper(filesystem)
  436. return filesystem
  437. def _unwrap_s3_serialization_workaround(
  438. filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"],
  439. ):
  440. if isinstance(filesystem, _S3FileSystemWrapper):
  441. filesystem = filesystem.unwrap()
  442. return filesystem
  443. class _S3FileSystemWrapper:
  444. """pyarrow.fs.S3FileSystem wrapper that can be deserialized safely.
  445. Importing pyarrow.fs during reconstruction triggers the pyarrow
  446. S3 subsystem initialization.
  447. NOTE: This is only needed for pyarrow<14.0.0 and should be removed
  448. once the minimum supported pyarrow version exceeds that.
  449. See https://github.com/apache/arrow/pull/38375 for context.
  450. """
  451. def __init__(self, fs: "pyarrow.fs.FileSystem"):
  452. self._fs = fs
  453. def unwrap(self):
  454. return self._fs
  455. @classmethod
  456. def _reconstruct(cls, fs_reconstruct, fs_args):
  457. # Implicitly trigger S3 subsystem initialization by importing
  458. # pyarrow.fs.
  459. import pyarrow.fs # noqa: F401
  460. return cls(fs_reconstruct(*fs_args))
  461. def __reduce__(self):
  462. return _S3FileSystemWrapper._reconstruct, self._fs.__reduce__()
  463. def _resolve_kwargs(
  464. kwargs_fn: Callable[[], Dict[str, Any]], **kwargs
  465. ) -> Dict[str, Any]:
  466. if kwargs_fn:
  467. kwarg_overrides = kwargs_fn()
  468. kwargs.update(kwarg_overrides)
  469. return kwargs
  470. def _validate_shuffle_arg(
  471. shuffle: Union[Literal["files"], FileShuffleConfig, None],
  472. ) -> None:
  473. if not (
  474. shuffle is None or shuffle == "files" or isinstance(shuffle, FileShuffleConfig)
  475. ):
  476. raise ValueError(
  477. f"Invalid value for 'shuffle': {shuffle}. "
  478. "Valid values are None, 'files', `FileShuffleConfig`."
  479. )
  480. FileMetadata = TypeVar("FileMetadata")
  481. def _shuffle_file_metadata(
  482. paths: List[str],
  483. file_metadata: List[FileMetadata],
  484. shuffler: Union[Literal["files"], FileShuffleConfig, None],
  485. execution_idx: int,
  486. ) -> Tuple[List[str], List[FileMetadata]]:
  487. """Shuffle file paths and sizes together using the given shuffler."""
  488. if shuffler is None:
  489. return paths, file_metadata
  490. assert len(paths) == len(file_metadata), (
  491. "Number of paths and file metadata must match. "
  492. f"Got {len(paths)} paths and {len(file_metadata)} file metadata."
  493. )
  494. if len(paths) == 0:
  495. return paths, file_metadata
  496. if shuffler == "files":
  497. seed = None
  498. else:
  499. assert isinstance(shuffler, FileShuffleConfig)
  500. seed = shuffler.get_seed(execution_idx)
  501. file_metadata_shuffler = np.random.default_rng(seed)
  502. files_metadata = list(zip(paths, file_metadata))
  503. file_metadata_shuffler.shuffle(files_metadata)
  504. return list(map(list, zip(*files_metadata)))