| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- import itertools
- import logging
- from dataclasses import dataclass
- from typing import TYPE_CHECKING, Generic, Iterable, List, Optional, TypeVar
- import ray
- from ray.data._internal.execution.interfaces import TaskContext
- from ray.data.block import Block, BlockAccessor
- from ray.util.annotations import DeveloperAPI
- if TYPE_CHECKING:
- import pyarrow as pa
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type parameter shared by `WriteResult` and `Datasink`: the per-task return
# type of `Datasink.write`.
WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""
@dataclass
@DeveloperAPI
class WriteResult(Generic[WriteReturnType]):
    """Aggregated result of the Datasink write operations."""

    # Total number of written rows.
    num_rows: int
    # Total size in bytes of written data.
    size_bytes: int
    # All returned values of `Datasink.write`.
    write_returns: List[WriteReturnType]

    @classmethod
    def combine(cls, *wrs: "WriteResult") -> "WriteResult":
        """Merge several write results into a single one.

        Row counts and byte sizes are summed, and the ``write_returns`` lists
        are concatenated in the order the results were passed.

        Args:
            *wrs: The results to merge. With no arguments the merged result is
                empty (zero rows, zero bytes, no write returns).

        Returns:
            A new instance of ``cls`` holding the combined totals.
        """
        return cls(
            num_rows=sum(wr.num_rows for wr in wrs),
            size_bytes=sum(wr.size_bytes for wr in wrs),
            # Flatten lazily rather than unpacking a temporary list of lists.
            write_returns=list(
                itertools.chain.from_iterable(wr.write_returns for wr in wrs)
            ),
        )
@DeveloperAPI
class Datasink(Generic[WriteReturnType]):
    """Base class holding the write-side logic of a custom sink.

    To send data somewhere that isn't supported out of the box, subclass this
    class and pass an instance to :meth:`~ray.data.Dataset.write_datasink`.
    """

    def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
        """Hook invoked once when a write job begins.

        Override it to prepare whatever the write tasks need, for example a
        staging bucket in S3.

        It runs on the driver as soon as the first input bundle is ready, right
        before write tasks are submitted. Because the schema is taken from that
        first bundle, schema-dependent setup can happen here.

        Args:
            schema: PyArrow schema of the data being written, extracted
                automatically from the first input bundle. ``None`` when the
                input data has no schema.
        """

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> WriteReturnType:
        """Write a batch of blocks; each call corresponds to one write task.

        Args:
            blocks: The data blocks to write (a generator of blocks).
            ctx: ``TaskContext`` of the write task.

        Returns:
            The outcome of this write task. Once the whole write operator has
            finished, the values returned by all tasks are passed to
            ``Datasink.on_write_complete`` as ``WriteResult.write_returns``.
        """
        raise NotImplementedError

    def on_write_complete(self, write_result: WriteResult[WriteReturnType]):
        """Hook invoked when a write job finishes.

        A natural place to `commit` the written output. This method must
        succeed before ``write_datasink()`` returns to the user; if it fails,
        ``on_write_failed()`` is called instead.

        Args:
            write_result: Aggregated result of the write operator, including
                the per-task write returns and statistics.
        """

    def on_write_failed(self, error: Exception) -> None:
        """Hook invoked, on a best-effort basis, when a write job fails.

        Args:
            error: The first error that was encountered.
        """

    def get_name(self) -> str:
        """Human-readable name of this datasink, used to name the write tasks.

        The class name is returned with one leading underscore and a trailing
        ``"Datasink"`` removed, when present.
        """
        suffix = "Datasink"
        label = type(self).__name__
        if label[:1] == "_":
            label = label[1:]
        return label[: -len(suffix)] if label.endswith(suffix) else label

    @property
    def supports_distributed_writes(self) -> bool:
        """Whether write tasks may run on any node.

        When ``False``, write tasks are launched only on the driver's node.
        """
        return True

    @property
    def min_rows_per_write(self) -> Optional[int]:
        """Target row count for each :meth:`~ray.data.Datasink.write` call.

        ``None`` means Ray Data picks the number of rows itself.
        """
        return None
@DeveloperAPI
class DummyOutputDatasink(Datasink[None]):
    """A toy datasink, for tests, that ships every block to a Ray actor.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import DummyOutputDatasink
        >>> output = DummyOutputDatasink()
        >>> ray.data.range(10).write_datasink(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        data_context = ray.data.DataContext.get_current()

        # Stand-in for an external system: a real datasink would send data to
        # storage, whereas write tasks here forward blocks to this actor, which
        # only counts rows.
        @ray.remote(scheduling_strategy=data_context.scheduling_strategy)
        class DataSink:
            def __init__(self):
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block) -> None:
                accessor = BlockAccessor.for_block(block)
                self.rows_written += accessor.num_rows()

            def get_rows_written(self):
                return self.rows_written

        self.data_sink = DataSink.remote()
        # Driver-side counters updated by the completion/failure hooks.
        self.num_ok = 0
        self.num_failed = 0
        # Clearing this flag makes `write` raise, to exercise failure paths.
        self.enabled = True

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Send every block to the actor and wait until all are recorded.

        Raises:
            ValueError: If ``self.enabled`` is falsy.
        """
        if not self.enabled:
            raise ValueError("disabled")
        pending_refs = [self.data_sink.write.remote(block) for block in blocks]
        ray.get(pending_refs)

    def on_write_complete(self, write_result: WriteResult[None]):
        """Count one successfully completed write job."""
        self.num_ok += 1

    def on_write_failed(self, error: Exception) -> None:
        """Count one failed write job."""
        self.num_failed += 1
def _gen_datasink_write_result(
    write_result_blocks: List[Block],
) -> WriteResult:
    """Aggregate per-task write-result blocks into one ``WriteResult``.

    Args:
        write_result_blocks: One block per write task. Each must be a
            single-row pandas DataFrame with ``num_rows``, ``size_bytes`` and
            ``write_return`` columns.

    Returns:
        A ``WriteResult`` whose ``num_rows`` and ``size_bytes`` are summed over
        all blocks and whose ``write_returns`` holds each block's
        ``write_return`` value, in input order.
    """
    import pandas as pd

    # Internal invariant: every write task reports exactly one summary row.
    assert all(
        isinstance(block, pd.DataFrame) and len(block) == 1
        for block in write_result_blocks
    )

    total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
    total_size_bytes = sum(
        result["size_bytes"].sum() for result in write_result_blocks
    )
    # Positional access: `series[0]` is a *label* lookup and raises KeyError
    # when the single row's index label isn't 0 (e.g. a sliced frame).
    write_returns = [
        result["write_return"].iloc[0] for result in write_result_blocks
    ]
    return WriteResult(total_num_rows, total_size_bytes, write_returns)
|