| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- import logging
- from typing import Optional
- from ray.data._internal.execution.execution_callback import (
- ExecutionCallback,
- remove_execution_callback,
- )
- from ray.data._internal.execution.streaming_executor import StreamingExecutor
- from ray.data.block import Block
- from ray.data.checkpoint import CheckpointConfig
- from ray.data.checkpoint.checkpoint_filter import BatchBasedCheckpointFilter
- from ray.types import ObjectRef
- logger = logging.getLogger(__name__)
class LoadCheckpointCallback(ExecutionCallback):
    """Execution callback that loads checkpoint data and cleans up after a run.

    The checkpoint is loaded through the checkpoint filter right before
    execution starts and is handed out via :meth:`load_checkpoint`. When
    execution ends (successfully or not) the callback unregisters itself
    from the executor's data context; after a successful run it also
    deletes the checkpoint if the config requests it.
    """

    def __init__(self, config: CheckpointConfig):
        """Bind the callback to ``config`` and build its checkpoint filter.

        Args:
            config: Checkpoint configuration; must not be ``None``.
        """
        assert config is not None
        self._config = config
        # The filter is built through the factory hook so subclasses can
        # substitute their own implementation.
        self._ckpt_filter = self._create_checkpoint_filter(config)
        # Populated by ``before_execution_starts``.
        self._checkpoint_ref: Optional[ObjectRef[Block]] = None

    def _create_checkpoint_filter(
        self, config: CheckpointConfig
    ) -> BatchBasedCheckpointFilter:
        """Build the checkpoint filter used by this callback.

        This is an extension point: override it in a subclass to plug in a
        different filter implementation.

        Args:
            config: Checkpoint configuration passed to the filter.

        Returns:
            A ``BatchBasedCheckpointFilter`` constructed from ``config``.
        """
        checkpoint_filter = BatchBasedCheckpointFilter(config)
        return checkpoint_filter

    def _detach_from_context(self, executor: StreamingExecutor) -> None:
        """Check the executor uses this callback's config, then unregister."""
        data_context = executor._data_context
        assert data_context.checkpoint_config is self._config
        remove_execution_callback(self, data_context)

    def before_execution_starts(self, executor: StreamingExecutor):
        """Load the checkpoint data and keep a reference to it.

        Args:
            executor: Executor about to run; its data context must carry the
                same checkpoint config object this callback was built with.
        """
        assert executor._data_context.checkpoint_config is self._config
        loaded_ref = self._ckpt_filter.load_checkpoint()
        self._checkpoint_ref = loaded_ref

    def after_execution_succeeds(self, executor: StreamingExecutor):
        """Unregister the callback and, if configured, delete the checkpoint.

        Deletion is best-effort: any exception raised while deleting is
        logged as a warning and not propagated.

        Args:
            executor: Executor that finished successfully.
        """
        self._detach_from_context(executor)

        try:
            should_delete = self._config.delete_checkpoint_on_success
            if should_delete:
                self._ckpt_filter.delete_checkpoint()
        except Exception:
            logger.warning("Failed to delete checkpoint data.", exc_info=True)

    def after_execution_fails(self, executor: StreamingExecutor, error: Exception):
        """Unregister the callback; checkpoint data is not deleted here.

        Args:
            executor: Executor whose run failed.
            error: The exception that caused the failure (unused).
        """
        self._detach_from_context(executor)

    def load_checkpoint(self) -> ObjectRef[Block]:
        """Return the checkpoint reference loaded before execution started.

        Returns:
            The reference stored by ``before_execution_starts``.
        """
        checkpoint_ref = self._checkpoint_ref
        assert checkpoint_ref is not None
        return checkpoint_ref
|