| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- import logging
- import os
- import uuid
- from abc import abstractmethod
- from pyarrow import parquet as pq
- from ray.data._internal.util import call_with_retry
- from ray.data.block import BlockAccessor
- from ray.data.checkpoint import CheckpointBackend, CheckpointConfig
- from ray.data.context import DataContext
- from ray.data.datasource.path_util import _unwrap_protocol
# Module-level logger, named after this module via `__name__`.
- logger = logging.getLogger(__name__)
class CheckpointWriter:
    """Abstract interface for writing row-level checkpoints to a backend.

    The constructor copies the settings it needs from the supplied
    `CheckpointConfig` onto the instance. Concrete subclasses must override
    `write_block_checkpoint()`. Use `CheckpointWriter.create()` to obtain the
    writer that matches a config's backend.
    """

    def __init__(self, config: CheckpointConfig):
        """Store the checkpoint settings read from `config`.

        Args:
            config: Checkpoint configuration supplying the checkpoint path,
                ID column, filesystem, and number of write threads.
        """
        self.ckpt_config = config
        # Checkpoint path after being passed through `_unwrap_protocol`.
        self.checkpoint_path_unwrapped = _unwrap_protocol(
            self.ckpt_config.checkpoint_path
        )
        self.id_col = self.ckpt_config.id_column
        self.filesystem = self.ckpt_config.filesystem
        self.write_num_threads = self.ckpt_config.write_num_threads

    @abstractmethod
    def write_block_checkpoint(self, block: BlockAccessor):
        """Write a checkpoint for all rows in a single block.

        The checkpoint belongs in the checkpoint output directory given by
        `self.checkpoint_path_unwrapped`. Subclasses of `CheckpointWriter`
        must override this method.

        Args:
            block: Accessor for the block whose rows should be checkpointed.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # This class does not use `ABCMeta`, so `@abstractmethod` alone does
        # not prevent instantiation. Fail loudly instead of silently
        # returning without writing any checkpoint.
        raise NotImplementedError(
            f"{type(self).__name__} must implement write_block_checkpoint()"
        )

    @staticmethod
    def create(config: CheckpointConfig) -> "CheckpointWriter":
        """Create the `CheckpointWriter` matching `config.backend`.

        Args:
            config: Checkpoint configuration whose `backend` selects the
                writer implementation.

        Returns:
            A `BatchBasedCheckpointWriter` for the cloud object storage and
            file storage backends.

        Raises:
            NotImplementedError: If the backend has no writer implementation.
        """
        backend = config.backend

        if backend in (
            CheckpointBackend.CLOUD_OBJECT_STORAGE,
            CheckpointBackend.FILE_STORAGE,
        ):
            return BatchBasedCheckpointWriter(config)

        raise NotImplementedError(f"Backend {backend} not implemented")
class BatchBasedCheckpointWriter(CheckpointWriter):
    """CheckpointWriter for batch-based backends.

    For each non-empty block, the ID column is written as one uniquely named
    Parquet file inside the checkpoint directory.
    """

    def __init__(self, config: CheckpointConfig):
        """Initialize the writer and create the checkpoint directory.

        Args:
            config: Checkpoint configuration forwarded to `CheckpointWriter`.
        """
        super().__init__(config)
        self.filesystem.create_dir(self.checkpoint_path_unwrapped, recursive=True)

    def write_block_checkpoint(self, block: BlockAccessor):
        """Write the ID column of `block` to a new Parquet checkpoint file.

        Empty blocks are skipped and produce no file. Otherwise the column
        named by `self.id_col` is written to `<uuid4>.parquet` under
        `self.checkpoint_path_unwrapped`, going through `call_with_retry`
        with the current `DataContext`'s `retried_io_errors` as the match
        list.

        Args:
            block: Accessor for the block whose row IDs should be
                checkpointed.

        Returns:
            `None` for an empty block; otherwise whatever `call_with_retry`
            returns for the write.

        Raises:
            Exception: Any error raised by the write is logged and re-raised.
        """
        # Imported locally so this method does not depend on a new
        # module-level import.
        import posixpath

        if block.num_rows() == 0:
            return

        file_name = f"{uuid.uuid4()}.parquet"
        # The filesystem is driven through the pyarrow filesystem API
        # (`create_dir`, `write_table(filesystem=...)`), whose paths are
        # "/"-separated on every platform. `os.path.join` would insert "\\"
        # on Windows, so join with `posixpath` instead.
        ckpt_file_path = posixpath.join(self.checkpoint_path_unwrapped, file_name)

        checkpoint_ids_block = block.select(columns=[self.id_col])
        # `pyarrow.parquet.write_table` requires a PyArrow table. It errors if
        # the block is a pandas DataFrame.
        checkpoint_ids_table = BlockAccessor.for_block(checkpoint_ids_block).to_arrow()

        def _write():
            pq.write_table(
                checkpoint_ids_table,
                ckpt_file_path,
                filesystem=self.filesystem,
            )

        try:
            return call_with_retry(
                _write,
                description=f"Write checkpoint file: {file_name}",
                match=DataContext.get_current().retried_io_errors,
            )
        except Exception:
            logger.exception("Checkpoint write failed: %s", file_name)
            raise
|