filename_provider.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from typing import Any, Dict, Optional
  2. from ray.data.block import Block
  3. from ray.util.annotations import PublicAPI
  4. @PublicAPI(stability="alpha")
  5. class FilenameProvider:
  6. """Generates filenames when you write a :class:`~ray.data.Dataset`.
  7. Use this class to customize the filenames used when writing a Dataset.
  8. Some methods write each row to a separate file, while others write each block to a
  9. separate file. For example, :meth:`ray.data.Dataset.write_images` writes individual
  10. rows, and :func:`ray.data.Dataset.write_parquet` writes blocks of data. For more
  11. information about blocks, see :ref:`Data internals <datasets_scheduling>`.
  12. If you're writing each row to a separate file, implement
  13. :meth:`~FilenameProvider.get_filename_for_row`. Otherwise, implement
  14. :meth:`~FilenameProvider.get_filename_for_block`.
  15. Example:
  16. This snippet shows you how to encode labels in written files. For example, if
  17. `"cat"` is a label, you might write a file named `cat_000000_000000_000000.png`.
  18. .. testcode::
  19. import ray
  20. from ray.data.datasource import FilenameProvider
  21. class ImageFilenameProvider(FilenameProvider):
  22. def __init__(self, file_format: str):
  23. self.file_format = file_format
  24. def get_filename_for_row(self, row, write_uuid, task_index, block_index, row_index):
  25. return (
  26. f"{row['label']}_{write_uuid}_{task_index:06}_{block_index:06}"
  27. f"_{row_index:06}.{self.file_format}"
  28. )
  29. ds = ray.data.read_parquet("s3://anonymous@ray-example-data/images.parquet")
  30. ds.write_images(
  31. "/tmp/results",
  32. column="image",
  33. filename_provider=ImageFilenameProvider("png")
  34. )
  35. """ # noqa: E501
  36. def get_filename_for_block(
  37. self, block: Block, write_uuid: str, task_index: int, block_index: int
  38. ) -> str:
  39. """Generate a filename for a block of data.
  40. .. note::
  41. Filenames must be unique and deterministic for a given write UUID, and
  42. task and block index.
  43. A block consists of multiple rows and corresponds to a single output file.
  44. Each task might produce a different number of blocks.
  45. Args:
  46. block: The block that will be written to a file.
  47. write_uuid: The UUID of the write operation.
  48. task_index: The index of the write task.
  49. block_index: The index of the block *within* the write task.
  50. """
  51. raise NotImplementedError
  52. def get_filename_for_row(
  53. self,
  54. row: Dict[str, Any],
  55. write_uuid: str,
  56. task_index: int,
  57. block_index: int,
  58. row_index: int,
  59. ) -> str:
  60. """Generate a filename for a row.
  61. .. note::
  62. Filenames must be unique and deterministic for a given write UUID, and
  63. task, block, and row index.
  64. A block consists of multiple rows, and each row corresponds to a single
  65. output file. Each task might produce a different number of blocks, and each
  66. block might contain a different number of rows.
  67. .. tip::
  68. If you require a contiguous row index into the global dataset, use
  69. :meth:`~ray.data.Dataset.iter_rows`. This method is single-threaded and
  70. isn't recommended for large datasets.
  71. Args:
  72. row: The row that will be written to a file.
  73. write_uuid: The UUID of the write operation.
  74. task_index: The index of the write task.
  75. block_index: The index of the block *within* the write task.
  76. row_index: The index of the row *within* the block.
  77. """
  78. raise NotImplementedError
  79. class _DefaultFilenameProvider(FilenameProvider):
  80. def __init__(
  81. self, dataset_uuid: Optional[str] = None, file_format: Optional[str] = None
  82. ):
  83. self._dataset_uuid = dataset_uuid
  84. self._file_format = file_format
  85. def get_filename_for_block(
  86. self, block: Block, write_uuid: str, task_index: int, block_index: int
  87. ) -> str:
  88. file_id = f"{write_uuid}_{task_index:06}_{block_index:06}"
  89. return self._generate_filename(file_id)
  90. def get_filename_for_row(
  91. self,
  92. row: Dict[str, Any],
  93. write_uuid: str,
  94. task_index: int,
  95. block_index: int,
  96. row_index: int,
  97. ) -> str:
  98. file_id = f"{write_uuid}_{task_index:06}_{block_index:06}_{row_index:06}"
  99. return self._generate_filename(file_id)
  100. def _generate_filename(self, file_id: str) -> str:
  101. filename = ""
  102. if self._dataset_uuid is not None:
  103. filename += f"{self._dataset_uuid}_"
  104. filename += file_id
  105. if self._file_format is not None:
  106. filename += f".{self._file_format}"
  107. return filename