load_checkpoint_callback.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import logging
  2. from typing import Optional
  3. from ray.data._internal.execution.execution_callback import (
  4. ExecutionCallback,
  5. remove_execution_callback,
  6. )
  7. from ray.data._internal.execution.streaming_executor import StreamingExecutor
  8. from ray.data.block import Block
  9. from ray.data.checkpoint import CheckpointConfig
  10. from ray.data.checkpoint.checkpoint_filter import BatchBasedCheckpointFilter
  11. from ray.types import ObjectRef
  12. logger = logging.getLogger(__name__)
  13. class LoadCheckpointCallback(ExecutionCallback):
  14. """ExecutionCallback that handles checkpoints."""
  15. def __init__(self, config: CheckpointConfig):
  16. assert config is not None
  17. self._config = config
  18. self._ckpt_filter = self._create_checkpoint_filter(config)
  19. self._checkpoint_ref: Optional[ObjectRef[Block]] = None
  20. def _create_checkpoint_filter(
  21. self, config: CheckpointConfig
  22. ) -> BatchBasedCheckpointFilter:
  23. """Factory method to create the checkpoint filter.
  24. Subclasses can override this to use a different filter implementation.
  25. """
  26. return BatchBasedCheckpointFilter(config)
  27. def before_execution_starts(self, executor: StreamingExecutor):
  28. assert self._config is executor._data_context.checkpoint_config
  29. # Load checkpoint data before execution starts.
  30. self._checkpoint_ref = self._ckpt_filter.load_checkpoint()
  31. def after_execution_succeeds(self, executor: StreamingExecutor):
  32. assert self._config is executor._data_context.checkpoint_config
  33. # Remove the callback from the DataContext.
  34. remove_execution_callback(self, executor._data_context)
  35. # Delete checkpoint data.
  36. try:
  37. if self._config.delete_checkpoint_on_success:
  38. self._ckpt_filter.delete_checkpoint()
  39. except Exception:
  40. logger.warning("Failed to delete checkpoint data.", exc_info=True)
  41. def after_execution_fails(self, executor: StreamingExecutor, error: Exception):
  42. assert self._config is executor._data_context.checkpoint_config
  43. # Remove the callback from the DataContext.
  44. remove_execution_callback(self, executor._data_context)
  45. def load_checkpoint(self) -> ObjectRef[Block]:
  46. assert self._checkpoint_ref is not None
  47. return self._checkpoint_ref