util.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import logging
  2. from typing import Iterable
  3. from ray.data._internal.execution.interfaces.task_context import TaskContext
  4. from ray.data.block import Block, BlockAccessor, DataBatch
  5. from ray.data.checkpoint.interfaces import (
  6. CheckpointConfig,
  7. )
  8. logger = logging.getLogger(__name__)
  9. # Checkpoint keyword argument name
  10. CHECKPOINTED_IDS_KWARG_NAME = "checkpointed_ids"
  11. def filter_checkpointed_rows_for_blocks(
  12. blocks: Iterable[Block],
  13. task_context: TaskContext,
  14. checkpoint_config: CheckpointConfig,
  15. ) -> Iterable[Block]:
  16. """For each block, filter rows that have already been checkpointed
  17. and yield the resulting block."""
  18. from ray.data.checkpoint.checkpoint_filter import (
  19. BatchBasedCheckpointFilter,
  20. )
  21. ckpt_filter = BatchBasedCheckpointFilter(checkpoint_config)
  22. checkpointed_ids = task_context.kwargs[CHECKPOINTED_IDS_KWARG_NAME]
  23. def filter_fn(block: Block) -> Block:
  24. return ckpt_filter.filter_rows_for_block(
  25. block=block,
  26. checkpointed_ids=checkpointed_ids,
  27. )
  28. for block in blocks:
  29. filtered_block = filter_fn(block)
  30. ba = BlockAccessor.for_block(filtered_block)
  31. if ba.num_rows() > 0:
  32. yield filtered_block
  33. def filter_checkpointed_rows_for_batches(
  34. batches: Iterable[DataBatch],
  35. task_context: TaskContext,
  36. checkpoint_config: CheckpointConfig,
  37. ) -> Iterable[DataBatch]:
  38. """For each batch, filter rows that have already been checkpointed
  39. and yield the resulting batches."""
  40. from ray.data.checkpoint.checkpoint_filter import (
  41. BatchBasedCheckpointFilter,
  42. )
  43. ckpt_filter = BatchBasedCheckpointFilter(checkpoint_config)
  44. checkpointed_ids = task_context.kwargs[CHECKPOINTED_IDS_KWARG_NAME]
  45. def filter_fn(batch: DataBatch) -> DataBatch:
  46. return ckpt_filter.filter_rows_for_batch(
  47. batch=batch,
  48. checkpointed_ids=checkpointed_ids,
  49. )
  50. for batch in batches:
  51. filtered_batch = filter_fn(batch)
  52. yield filtered_batch