checkpoint_writer.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
import logging
import os
import posixpath
import uuid
from abc import ABC, abstractmethod

from pyarrow import parquet as pq

from ray.data._internal.util import call_with_retry
from ray.data.block import BlockAccessor
from ray.data.checkpoint import CheckpointBackend, CheckpointConfig
from ray.data.context import DataContext
from ray.data.datasource.path_util import _unwrap_protocol
  11. logger = logging.getLogger(__name__)
  12. class CheckpointWriter:
  13. """Abstract class which defines the interface for writing row-level
  14. checkpoints based on varying backends.
  15. Subclasses must implement `.write_block_checkpoint()`."""
  16. def __init__(self, config: CheckpointConfig):
  17. self.ckpt_config = config
  18. self.checkpoint_path_unwrapped = _unwrap_protocol(
  19. self.ckpt_config.checkpoint_path
  20. )
  21. self.id_col = self.ckpt_config.id_column
  22. self.filesystem = self.ckpt_config.filesystem
  23. self.write_num_threads = self.ckpt_config.write_num_threads
  24. @abstractmethod
  25. def write_block_checkpoint(self, block: BlockAccessor):
  26. """Write a checkpoint for all rows in a single block to the checkpoint
  27. output directory given by `self.checkpoint_path`.
  28. Subclasses of `CheckpointWriter` must implement this method."""
  29. ...
  30. @staticmethod
  31. def create(config: CheckpointConfig) -> "CheckpointWriter":
  32. """Factory method to create a `CheckpointWriter` based on the
  33. provided `CheckpointConfig`."""
  34. backend = config.backend
  35. if backend in [
  36. CheckpointBackend.CLOUD_OBJECT_STORAGE,
  37. CheckpointBackend.FILE_STORAGE,
  38. ]:
  39. return BatchBasedCheckpointWriter(config)
  40. raise NotImplementedError(f"Backend {backend} not implemented")
  41. class BatchBasedCheckpointWriter(CheckpointWriter):
  42. """CheckpointWriter for batch-based backends."""
  43. def __init__(self, config: CheckpointConfig):
  44. super().__init__(config)
  45. self.filesystem.create_dir(self.checkpoint_path_unwrapped, recursive=True)
  46. def write_block_checkpoint(self, block: BlockAccessor):
  47. """Write a checkpoint for all rows in a single block to the checkpoint
  48. output directory given by `self.checkpoint_path`.
  49. Subclasses of `CheckpointWriter` must implement this method."""
  50. if block.num_rows() == 0:
  51. return
  52. file_name = f"{uuid.uuid4()}.parquet"
  53. ckpt_file_path = os.path.join(self.checkpoint_path_unwrapped, file_name)
  54. checkpoint_ids_block = block.select(columns=[self.id_col])
  55. # `pyarrow.parquet.write_parquet` requires a PyArrow table. It errors if the block is
  56. # a pandas DataFrame.
  57. checkpoint_ids_table = BlockAccessor.for_block(checkpoint_ids_block).to_arrow()
  58. def _write():
  59. pq.write_table(
  60. checkpoint_ids_table,
  61. ckpt_file_path,
  62. filesystem=self.filesystem,
  63. )
  64. try:
  65. return call_with_retry(
  66. _write,
  67. description=f"Write checkpoint file: {file_name}",
  68. match=DataContext.get_current().retried_io_errors,
  69. )
  70. except Exception:
  71. logger.exception(f"Checkpoint write failed: {file_name}")
  72. raise