file_datasink.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import logging
  2. import posixpath
  3. from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
  4. from urllib.parse import urlparse
  5. from ray._common.retry import call_with_retry
  6. from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
  7. from ray.data._internal.execution.interfaces import TaskContext
  8. from ray.data._internal.planner.plan_write_op import WRITE_UUID_KWARG_NAME
  9. from ray.data._internal.savemode import SaveMode
  10. from ray.data._internal.util import (
  11. RetryingPyFileSystem,
  12. _is_local_scheme,
  13. )
  14. from ray.data._internal.utils.arrow_utils import add_creatable_buckets_param_if_s3_uri
  15. from ray.data.block import Block, BlockAccessor
  16. from ray.data.context import DataContext
  17. from ray.data.datasource.datasink import Datasink, WriteResult
  18. from ray.data.datasource.filename_provider import (
  19. FilenameProvider,
  20. _DefaultFilenameProvider,
  21. )
  22. from ray.data.datasource.path_util import _resolve_paths_and_filesystem
  23. from ray.util.annotations import DeveloperAPI
  24. if TYPE_CHECKING:
  25. import pyarrow
  26. logger = logging.getLogger(__name__)
  27. class _FileDatasink(Datasink[None]):
  28. def __init__(
  29. self,
  30. path: str,
  31. *,
  32. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  33. try_create_dir: bool = True,
  34. open_stream_args: Optional[Dict[str, Any]] = None,
  35. filename_provider: Optional[FilenameProvider] = None,
  36. dataset_uuid: Optional[str] = None,
  37. file_format: Optional[str] = None,
  38. mode: SaveMode = SaveMode.APPEND,
  39. ):
  40. """Initialize this datasink.
  41. Args:
  42. path: The folder to write files to.
  43. filesystem: The filesystem to write files to. If not provided, the
  44. filesystem is inferred from the path.
  45. try_create_dir: Whether to create the directory to write files to.
  46. open_stream_args: Arguments to pass to ``filesystem.open_output_stream``.
  47. filename_provider: A :class:`ray.data.datasource.FilenameProvider` that
  48. generates filenames for each row or block.
  49. dataset_uuid: The UUID of the dataset being written. If specified, it's
  50. included in the filename.
  51. file_format: The file extension. If specified, files are written with this
  52. extension.
  53. """
  54. if open_stream_args is None:
  55. open_stream_args = {}
  56. if filename_provider is None:
  57. filename_provider = _DefaultFilenameProvider(
  58. dataset_uuid=dataset_uuid, file_format=file_format
  59. )
  60. self._data_context = DataContext.get_current()
  61. self.unresolved_path = path
  62. paths, self.filesystem = _resolve_paths_and_filesystem(path, filesystem)
  63. self.filesystem = RetryingPyFileSystem.wrap(
  64. self.filesystem, retryable_errors=self._data_context.retried_io_errors
  65. )
  66. assert len(paths) == 1, len(paths)
  67. self.path = paths[0]
  68. self.try_create_dir = try_create_dir
  69. self.open_stream_args = open_stream_args
  70. self.filename_provider = filename_provider
  71. self.dataset_uuid = dataset_uuid
  72. self.file_format = file_format
  73. self.mode = mode
  74. self.has_created_dir = False
  75. self._skip_write = False
  76. self._write_started = False
  77. def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
  78. return self.filesystem.open_output_stream(path, **self.open_stream_args)
  79. def on_write_start(self, schema: Optional["pyarrow.Schema"] = None) -> None:
  80. # Make idempotent - if already called, return early.
  81. if self._write_started:
  82. return
  83. self._write_started = True
  84. from pyarrow.fs import FileType
  85. dir_exists = (
  86. self.filesystem.get_file_info(self.path).type is not FileType.NotFound
  87. )
  88. if dir_exists:
  89. if self.mode == SaveMode.ERROR:
  90. raise ValueError(
  91. f"Path {self.path} already exists. "
  92. "If this is unexpected, use mode='ignore' to ignore those files"
  93. )
  94. if self.mode == SaveMode.IGNORE:
  95. logger.warning(f"[SaveMode={self.mode}] Skipping {self.path}")
  96. self._skip_write = True
  97. return
  98. if self.mode == SaveMode.OVERWRITE:
  99. logger.warning(f"[SaveMode={self.mode}] Replacing contents {self.path}")
  100. self.filesystem.delete_dir_contents(self.path)
  101. self.has_created_dir = self._create_dir(self.path)
  102. def _create_dir(self, dest) -> bool:
  103. """Create a directory to write files to.
  104. If ``try_create_dir`` is ``False``, this method is a no-op.
  105. """
  106. from pyarrow.fs import FileType
  107. # We should skip creating directories in s3 unless the user specifically
  108. # overrides this behavior. PyArrow's s3fs implementation for create_dir
  109. # will attempt to check if the parent directory exists before trying to
  110. # create the directory (with recursive=True it will try to do this to
  111. # all of the directories until the root of the bucket). An IAM Policy that
  112. # restricts access to a subset of prefixes within the bucket might cause
  113. # the creation of the directory to fail even if the permissions should
  114. # allow the data can be written to the specified path. For example if a
  115. # a policy only allows users to write blobs prefixed with s3://bucket/foo
  116. # a call to create_dir for s3://bucket/foo/bar will fail even though it
  117. # should not.
  118. parsed_uri = urlparse(dest)
  119. is_s3_uri = parsed_uri.scheme == "s3"
  120. skip_create_dir_for_s3 = is_s3_uri and not self._data_context.s3_try_create_dir
  121. if self.try_create_dir and not skip_create_dir_for_s3:
  122. if self.filesystem.get_file_info(dest).type is FileType.NotFound:
  123. # Arrow's S3FileSystem doesn't allow creating buckets by default, so we
  124. # add a query arg enabling bucket creation if an S3 URI is provided.
  125. tmp = add_creatable_buckets_param_if_s3_uri(dest)
  126. self.filesystem.create_dir(tmp, recursive=True)
  127. return True
  128. return False
  129. def write(
  130. self,
  131. blocks: Iterable[Block],
  132. ctx: TaskContext,
  133. ) -> None:
  134. builder = DelegatingBlockBuilder()
  135. for block in blocks:
  136. builder.add_block(block)
  137. block = builder.build()
  138. block_accessor = BlockAccessor.for_block(block)
  139. if block_accessor.num_rows() == 0:
  140. logger.warning(f"Skipped writing empty block to {self.path}")
  141. return
  142. self.write_block(block_accessor, 0, ctx)
  143. def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
  144. raise NotImplementedError
  145. def on_write_complete(self, write_result: WriteResult[None]):
  146. # If no rows were written, we can delete the directory.
  147. if self.has_created_dir and write_result.num_rows == 0:
  148. self.filesystem.delete_dir(self.path)
  149. @property
  150. def supports_distributed_writes(self) -> bool:
  151. return not _is_local_scheme(self.unresolved_path)
  152. @DeveloperAPI
  153. class RowBasedFileDatasink(_FileDatasink):
  154. """A datasink that writes one row to each file.
  155. Subclasses must implement ``write_row_to_file`` and call the superclass constructor.
  156. Examples:
  157. .. testcode::
  158. import io
  159. from typing import Any, Dict
  160. import pyarrow
  161. from PIL import Image
  162. from ray.data.datasource import RowBasedFileDatasink
  163. class ImageDatasink(RowBasedFileDatasink):
  164. def __init__(self, path: str, *, column: str, file_format: str = "png"):
  165. super().__init__(path, file_format=file_format)
  166. self._file_format = file_format
  167. self._column = column
  168. def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
  169. image = Image.fromarray(row[self._column])
  170. buffer = io.BytesIO()
  171. image.save(buffer, format=self._file_format)
  172. file.write(buffer.getvalue())
  173. """ # noqa: E501
  174. def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
  175. """Write a row to a file.
  176. Args:
  177. row: The row to write.
  178. file: The file to write the row to.
  179. """
  180. raise NotImplementedError
  181. def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
  182. for row_index, row in enumerate(block.iter_rows(public_row_format=False)):
  183. filename = self.filename_provider.get_filename_for_row(
  184. row,
  185. ctx.kwargs[WRITE_UUID_KWARG_NAME],
  186. ctx.task_idx,
  187. block_index,
  188. row_index,
  189. )
  190. write_path = posixpath.join(self.path, filename)
  191. logger.debug(f"Writing {write_path} file.")
  192. def write_row_to_path():
  193. with self.open_output_stream(write_path) as file:
  194. self.write_row_to_file(row, file)
  195. call_with_retry(
  196. write_row_to_path,
  197. description=f"write '{write_path}'",
  198. match=self._data_context.retried_io_errors,
  199. )
  200. @DeveloperAPI
  201. class BlockBasedFileDatasink(_FileDatasink):
  202. """A datasink that writes multiple rows to each file.
  203. Subclasses must implement ``write_block_to_file`` and call the superclass
  204. constructor.
  205. Examples:
  206. .. testcode::
  207. class CSVDatasink(BlockBasedFileDatasink):
  208. def __init__(self, path: str):
  209. super().__init__(path, file_format="csv")
  210. def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
  211. from pyarrow import csv
  212. csv.write_csv(block.to_arrow(), file)
  213. """ # noqa: E501
  214. def __init__(
  215. self, path, *, min_rows_per_file: Optional[int] = None, **file_datasink_kwargs
  216. ):
  217. super().__init__(path, **file_datasink_kwargs)
  218. self._min_rows_per_file = min_rows_per_file
  219. def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
  220. """Write a block of data to a file.
  221. Args:
  222. block: The block to write.
  223. file: The file to write the block to.
  224. """
  225. raise NotImplementedError
  226. def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
  227. filename = self.filename_provider.get_filename_for_block(
  228. block, ctx.kwargs[WRITE_UUID_KWARG_NAME], ctx.task_idx, block_index
  229. )
  230. write_path = posixpath.join(self.path, filename)
  231. def write_block_to_path():
  232. with self.open_output_stream(write_path) as file:
  233. self.write_block_to_file(block, file)
  234. logger.debug(f"Writing {write_path} file.")
  235. call_with_retry(
  236. write_block_to_path,
  237. description=f"write '{write_path}'",
  238. match=self._data_context.retried_io_errors,
  239. )
  240. @property
  241. def min_rows_per_write(self) -> Optional[int]:
  242. return self._min_rows_per_file