datasink.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import itertools
  2. import logging
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING, Generic, Iterable, List, Optional, TypeVar
  5. import ray
  6. from ray.data._internal.execution.interfaces import TaskContext
  7. from ray.data.block import Block, BlockAccessor
  8. from ray.util.annotations import DeveloperAPI
  9. if TYPE_CHECKING:
  10. import pyarrow as pa
# Logger named after this module.
logger = logging.getLogger(__name__)
# Type variable used to parametrize `WriteResult` and `Datasink` by the value a
# single `Datasink.write` call returns.
WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""
  14. @dataclass
  15. @DeveloperAPI
  16. class WriteResult(Generic[WriteReturnType]):
  17. """Aggregated result of the Datasink write operations."""
  18. # Total number of written rows.
  19. num_rows: int
  20. # Total size in bytes of written data.
  21. size_bytes: int
  22. # All returned values of `Datasink.write`.
  23. write_returns: List[WriteReturnType]
  24. @classmethod
  25. def combine(cls, *wrs: "WriteResult") -> "WriteResult":
  26. num_rows = sum(wr.num_rows for wr in wrs)
  27. size_bytes = sum(wr.size_bytes for wr in wrs)
  28. write_returns = list(itertools.chain(*[wr.write_returns for wr in wrs]))
  29. return WriteResult(
  30. num_rows=num_rows,
  31. size_bytes=size_bytes,
  32. write_returns=write_returns,
  33. )
  34. @DeveloperAPI
  35. class Datasink(Generic[WriteReturnType]):
  36. """Interface for defining write-related logic.
  37. If you want to write data to something that isn't built-in, subclass this class
  38. and call :meth:`~ray.data.Dataset.write_datasink`.
  39. """
  40. def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
  41. """Callback for when a write job starts.
  42. Use this method to perform setup for write tasks. For example, creating a
  43. staging bucket in S3.
  44. This is called on the driver when the first input bundle is ready, just
  45. before write tasks are submitted. The schema is extracted from the first
  46. input bundle, enabling schema-dependent initialization.
  47. Args:
  48. schema: The PyArrow schema of the data being written. This is
  49. automatically extracted from the first input bundle. May be None
  50. if the input data has no schema.
  51. """
  52. pass
  53. def write(
  54. self,
  55. blocks: Iterable[Block],
  56. ctx: TaskContext,
  57. ) -> WriteReturnType:
  58. """Write blocks. This is used by a single write task.
  59. Args:
  60. blocks: Generator of data blocks.
  61. ctx: ``TaskContext`` for the write task.
  62. Returns:
  63. Result of this write task. When the entire write operator finishes,
  64. All returned values will be passed as `WriteResult.write_returns`
  65. to `Datasink.on_write_complete`.
  66. """
  67. raise NotImplementedError
  68. def on_write_complete(self, write_result: WriteResult[WriteReturnType]):
  69. """Callback for when a write job completes.
  70. This can be used to `commit` a write output. This method must
  71. succeed prior to ``write_datasink()`` returning to the user. If this
  72. method fails, then ``on_write_failed()`` is called.
  73. Args:
  74. write_result: Aggregated result of the
  75. Write operator, containing write results and stats.
  76. """
  77. pass
  78. def on_write_failed(self, error: Exception) -> None:
  79. """Callback for when a write job fails.
  80. This is called on a best-effort basis on write failures.
  81. Args:
  82. error: The first error encountered.
  83. """
  84. pass
  85. def get_name(self) -> str:
  86. """Return a human-readable name for this datasink.
  87. This is used as the names of the write tasks.
  88. """
  89. name = type(self).__name__
  90. datasink_suffix = "Datasink"
  91. if name.startswith("_"):
  92. name = name[1:]
  93. if name.endswith(datasink_suffix):
  94. name = name[: -len(datasink_suffix)]
  95. return name
  96. @property
  97. def supports_distributed_writes(self) -> bool:
  98. """If ``False``, only launch write tasks on the driver's node."""
  99. return True
  100. @property
  101. def min_rows_per_write(self) -> Optional[int]:
  102. """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call.
  103. If ``None``, Ray Data passes a system-chosen number of rows.
  104. """
  105. return None
  106. @DeveloperAPI
  107. class DummyOutputDatasink(Datasink[None]):
  108. """An example implementation of a writable datasource for testing.
  109. Examples:
  110. >>> import ray
  111. >>> from ray.data.datasource import DummyOutputDatasink
  112. >>> output = DummyOutputDatasink()
  113. >>> ray.data.range(10).write_datasink(output)
  114. >>> assert output.num_ok == 1
  115. """
  116. def __init__(self):
  117. ctx = ray.data.DataContext.get_current()
  118. # Setup a dummy actor to send the data. In a real datasource, write
  119. # tasks would send data to an external system instead of a Ray actor.
  120. @ray.remote(scheduling_strategy=ctx.scheduling_strategy)
  121. class DataSink:
  122. def __init__(self):
  123. self.rows_written = 0
  124. self.enabled = True
  125. def write(self, block: Block) -> None:
  126. block = BlockAccessor.for_block(block)
  127. self.rows_written += block.num_rows()
  128. def get_rows_written(self):
  129. return self.rows_written
  130. self.data_sink = DataSink.remote()
  131. self.num_ok = 0
  132. self.num_failed = 0
  133. self.enabled = True
  134. def write(
  135. self,
  136. blocks: Iterable[Block],
  137. ctx: TaskContext,
  138. ) -> None:
  139. tasks = []
  140. if not self.enabled:
  141. raise ValueError("disabled")
  142. for b in blocks:
  143. tasks.append(self.data_sink.write.remote(b))
  144. ray.get(tasks)
  145. def on_write_complete(self, write_result: WriteResult[None]):
  146. self.num_ok += 1
  147. def on_write_failed(self, error: Exception) -> None:
  148. self.num_failed += 1
  149. def _gen_datasink_write_result(
  150. write_result_blocks: List[Block],
  151. ) -> WriteResult:
  152. import pandas as pd
  153. assert all(
  154. isinstance(block, pd.DataFrame) and len(block) == 1
  155. for block in write_result_blocks
  156. )
  157. total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
  158. total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks)
  159. write_returns = [result["write_return"][0] for result in write_result_blocks]
  160. return WriteResult(total_num_rows, total_size_bytes, write_returns)