split.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import itertools
  2. import logging
  3. from typing import Iterable, List, Tuple, Union
  4. import ray
  5. from ray.data._internal.memory_tracing import trace_deallocation
  6. from ray.data._internal.remote_fn import cached_remote_fn
  7. from ray.data.block import (
  8. Block,
  9. BlockAccessor,
  10. BlockExecStats,
  11. BlockMetadata,
  12. BlockPartition,
  13. )
  14. from ray.types import ObjectRef
  15. logger = logging.getLogger(__name__)
  16. def _calculate_blocks_rows(
  17. blocks_with_metadata: BlockPartition,
  18. ) -> List[int]:
  19. """Calculate the number of rows for a list of blocks with metadata."""
  20. get_num_rows = cached_remote_fn(_get_num_rows)
  21. block_rows = []
  22. for block, metadata in blocks_with_metadata:
  23. if metadata.num_rows is None:
  24. # Need to fetch number of rows.
  25. num_rows = ray.get(get_num_rows.remote(block))
  26. metadata.num_rows = num_rows
  27. else:
  28. num_rows = metadata.num_rows
  29. block_rows.append(num_rows)
  30. return block_rows
  31. def _generate_valid_indices(
  32. num_rows_per_block: List[int],
  33. split_indices: List[int],
  34. ) -> List[int]:
  35. """Generate valid split indices by apply min(index, total_num_rows)
  36. to every index."""
  37. total_rows = sum(num_rows_per_block)
  38. return [min(index, total_rows) for index in split_indices]
  39. def _generate_per_block_split_indices(
  40. num_rows_per_block: List[int],
  41. split_indices: List[int],
  42. ) -> List[List[int]]:
  43. """Given num rows per block and valid split indices, generate per block split indices.
  44. Args:
  45. num_rows_per_block: num of rows per block.
  46. split_indices: The (global) indices at which to split the blocks.
  47. Returns:
  48. Per block split indices indicates each input block's split point(s).
  49. """
  50. # for each split index, we iterate though the currnet input block
  51. # to see if the index falls into this block. if the index
  52. # falls into this block, we push it back to the current block's
  53. # split indices. Otherwise, we move on to the next block.
  54. per_block_split_indices = []
  55. current_input_block_id = 0
  56. current_block_split_indices = []
  57. current_block_global_offset = 0
  58. current_index_id = 0
  59. while current_index_id < len(split_indices):
  60. split_index = split_indices[current_index_id]
  61. current_block_row = num_rows_per_block[current_input_block_id]
  62. if split_index - current_block_global_offset <= current_block_row:
  63. current_block_split_indices.append(
  64. split_index - current_block_global_offset
  65. )
  66. current_index_id += 1
  67. continue
  68. per_block_split_indices.append(current_block_split_indices)
  69. current_block_split_indices = []
  70. current_block_global_offset += num_rows_per_block[current_input_block_id]
  71. current_input_block_id += 1
  72. # we might finished all the indices but there are still blocks left, also
  73. # current_block_split_indices might not be added yet.
  74. while len(per_block_split_indices) < len(num_rows_per_block):
  75. per_block_split_indices.append(current_block_split_indices)
  76. current_block_split_indices = []
  77. return per_block_split_indices
  78. def _split_single_block(
  79. block_id: int,
  80. block: Block,
  81. meta: BlockMetadata,
  82. split_indices: List[int],
  83. ) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
  84. """Split the provided block at the given indices.
  85. Args:
  86. block_id: the id of this block in the block list.
  87. block: block to be split.
  88. meta: metadata of the block, we expect meta.num is valid.
  89. split_indices: the indices where the block should be split.
  90. Returns:
  91. returns block_id, split blocks metadata, and a list of blocks
  92. in the following form. We return blocks in this way
  93. so that the owner of blocks could be the caller(driver)
  94. instead of worker itself.
  95. Tuple(block_id, split_blocks_meta), block0, block1 ...
  96. """
  97. split_meta = []
  98. split_blocks = []
  99. block_accessor = BlockAccessor.for_block(block)
  100. prev_index = 0
  101. # append one more entry at the last so we don't
  102. # need handle empty edge case.
  103. split_indices.append(meta.num_rows)
  104. for index in split_indices:
  105. logger.debug(f"slicing block {prev_index}:{index}")
  106. stats = BlockExecStats.builder()
  107. split_block = block_accessor.slice(prev_index, index)
  108. accessor = BlockAccessor.for_block(split_block)
  109. _meta = BlockMetadata(
  110. num_rows=accessor.num_rows(),
  111. size_bytes=accessor.size_bytes(),
  112. input_files=meta.input_files,
  113. exec_stats=stats.build(),
  114. )
  115. split_meta.append(_meta)
  116. split_blocks.append(split_block)
  117. prev_index = index
  118. results = [(block_id, split_meta)]
  119. results.extend(split_blocks)
  120. return tuple(results)
  121. def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
  122. """drop split indices that creates empty block split. This could happen when there
  123. are duplicated indices, or index equal to 0 (start of the block) or num_block_rows
  124. (end of the block).
  125. """
  126. prev_index = -1
  127. optimized_indices = []
  128. for index in block_split_indices:
  129. if index == 0 or index == num_rows:
  130. continue
  131. if index == prev_index:
  132. continue
  133. optimized_indices.append(index)
  134. prev_index = index
  135. return optimized_indices
  136. def _split_all_blocks(
  137. blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
  138. per_block_split_indices: List[List[int]],
  139. owned_by_consumer: bool,
  140. ) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
  141. """Split all the input blocks based on the split indices"""
  142. split_single_block = cached_remote_fn(_split_single_block)
  143. all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)
  144. per_block_split_metadata_futures = []
  145. per_block_split_block_refs = []
  146. # tracking splitted blocks for gc.
  147. blocks_splitted = []
  148. for block_id, block_split_indices in enumerate(per_block_split_indices):
  149. (block_ref, meta) = blocks_with_metadata[block_id]
  150. block_row = meta.num_rows
  151. block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
  152. if len(block_split_indices) == 0:
  153. # optimization: if no split is needed, we just need to add it to the
  154. # result
  155. all_blocks_split_results[block_id] = [(block_ref, meta)]
  156. else:
  157. # otherwise call split remote function.
  158. object_refs = split_single_block.options(
  159. scheduling_strategy="SPREAD", num_returns=2 + len(block_split_indices)
  160. ).remote(
  161. block_id,
  162. block_ref,
  163. meta,
  164. block_split_indices,
  165. )
  166. per_block_split_metadata_futures.append(object_refs[0])
  167. per_block_split_block_refs.append(object_refs[1:])
  168. blocks_splitted.append(block_ref)
  169. if per_block_split_metadata_futures:
  170. # only get metadata.
  171. per_block_split_metadata = ray.get(per_block_split_metadata_futures)
  172. for (block_id, meta), block_refs in zip(
  173. per_block_split_metadata, per_block_split_block_refs
  174. ):
  175. assert len(meta) == len(block_refs)
  176. all_blocks_split_results[block_id] = zip(block_refs, meta)
  177. # We make a copy for the blocks that have been splitted, so the input blocks
  178. # can be cleared if they are owned by consumer (consumer-owned blocks will
  179. # only be consumed by the owner).
  180. if owned_by_consumer:
  181. for b in blocks_splitted:
  182. trace_deallocation(b, "split._split_all_blocks")
  183. else:
  184. for b in blocks_splitted:
  185. trace_deallocation(b, "split._split_all_blocks", free=False)
  186. return itertools.chain.from_iterable(all_blocks_split_results)
  187. def _generate_global_split_results(
  188. all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
  189. global_split_sizes: List[int],
  190. ) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
  191. """Reassemble per block's split result into final split result."""
  192. result_blocks = []
  193. result_metas = []
  194. current_blocks = []
  195. current_meta = []
  196. current_split_size = 0
  197. current_split_id = 0
  198. while current_split_id < len(global_split_sizes):
  199. if current_split_size >= global_split_sizes[current_split_id]:
  200. assert current_split_size == global_split_sizes[current_split_id]
  201. result_blocks.append(current_blocks)
  202. result_metas.append(current_meta)
  203. current_blocks = []
  204. current_meta = []
  205. current_split_size = 0
  206. current_split_id += 1
  207. else:
  208. (block_ref, meta) = next(all_blocks_split_results)
  209. current_blocks.append(block_ref)
  210. current_meta.append(meta)
  211. current_split_size += meta.num_rows
  212. return result_blocks, result_metas
  213. def _split_at_indices(
  214. blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
  215. indices: List[int],
  216. owned_by_consumer: bool = True,
  217. block_rows: List[int] = None,
  218. ) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
  219. """Split blocks at the provided indices.
  220. Args:
  221. blocks_with_metadata: Block futures to split, including the associated metadata.
  222. indices: The (global) indices at which to split the blocks.
  223. owned_by_consumer: Whether the provided blocks are owned by the consumer.
  224. block_rows: The number of rows for each block, in case it has already been
  225. computed.
  226. Returns:
  227. The block split futures and their metadata. If an index split is empty, the
  228. corresponding block split will be empty .
  229. """
  230. # We implement the split in 3 phases.
  231. # phase 1: calculate the per block split indices.
  232. blocks_with_metadata = list(blocks_with_metadata)
  233. if len(blocks_with_metadata) == 0:
  234. return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1))
  235. if block_rows is None:
  236. block_rows = _calculate_blocks_rows(blocks_with_metadata)
  237. valid_indices = _generate_valid_indices(block_rows, indices)
  238. per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
  239. block_rows, valid_indices
  240. )
  241. # phase 2: split each block based on the indices from previous step.
  242. all_blocks_split_results: Iterable[
  243. Tuple[ObjectRef[Block], BlockMetadata]
  244. ] = _split_all_blocks(
  245. blocks_with_metadata, per_block_split_indices, owned_by_consumer
  246. )
  247. # phase 3: generate the final split.
  248. # first calculate the size for each split.
  249. helper = [0] + valid_indices + [sum(block_rows)]
  250. split_sizes = [helper[i] - helper[i - 1] for i in range(1, len(helper))]
  251. return _generate_global_split_results(all_blocks_split_results, split_sizes)
  252. def _get_num_rows(block: Block) -> int:
  253. """Get the number of rows contained in the provided block."""
  254. return BlockAccessor.for_block(block).num_rows()