table_block.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import collections
  2. import heapq
  3. from typing import (
  4. TYPE_CHECKING,
  5. Any,
  6. Dict,
  7. Iterator,
  8. List,
  9. Mapping,
  10. Optional,
  11. Sequence,
  12. Tuple,
  13. TypeVar,
  14. Union,
  15. )
  16. from ray._private.ray_constants import env_integer
  17. from ray.data._internal.block_builder import BlockBuilder
  18. from ray.data._internal.size_estimator import SizeEstimator
  19. from ray.data._internal.util import (
  20. NULL_SENTINEL,
  21. find_partition_index,
  22. is_nan,
  23. keys_equal,
  24. )
  25. from ray.data.block import (
  26. Block,
  27. BlockAccessor,
  28. BlockColumnAccessor,
  29. BlockExecStats,
  30. BlockMetadataWithSchema,
  31. BlockType,
  32. KeyType,
  33. U,
  34. )
  35. from ray.data.context import DEFAULT_TARGET_MAX_BLOCK_SIZE
  36. if TYPE_CHECKING:
  37. from ray.data._internal.planner.exchange.sort_task_spec import SortKey
  38. from ray.data.aggregate import AggregateFn
  39. T = TypeVar("T")
  40. # The max size of Python tuples to buffer before compacting them into a
  41. # table in the BlockBuilder.
  42. MAX_UNCOMPACTED_SIZE_BYTES = env_integer(
  43. "RAY_DATA_MAX_UNCOMPACTED_SIZE_BYTES", DEFAULT_TARGET_MAX_BLOCK_SIZE
  44. )
  45. class TableBlockBuilder(BlockBuilder):
  46. def __init__(self, block_type):
  47. # The set of uncompacted Python values buffered.
  48. self._columns = collections.defaultdict(list)
  49. # The set of compacted tables we have built so far.
  50. self._tables: List[Any] = []
  51. # Cursor into tables indicating up to which table we've accumulated table sizes.
  52. # This is used to defer table size calculation, which can be expensive for e.g.
  53. # Pandas DataFrames.
  54. # This cursor points to the first table for which we haven't accumulated a table
  55. # size.
  56. self._tables_size_cursor = 0
  57. # Accumulated table sizes, up to the table in _tables pointed to by
  58. # _tables_size_cursor.
  59. self._tables_size_bytes = 0
  60. # Size estimator for un-compacted table values.
  61. self._uncompacted_size = SizeEstimator()
  62. self._num_rows = 0
  63. self._num_uncompacted_rows = 0
  64. self._num_compactions = 0
  65. self._block_type = block_type
  66. def add(self, item: Union[dict, Mapping]) -> None:
  67. if hasattr(item, "as_pydict"):
  68. item = item.as_pydict()
  69. if not isinstance(item, collections.abc.Mapping):
  70. raise ValueError(
  71. "Returned elements of an TableBlock must be of type `dict`, "
  72. "got {} (type {}).".format(item, type(item))
  73. )
  74. # Fill in missing columns with None.
  75. for column_name in item:
  76. if column_name not in self._columns:
  77. self._columns[column_name] = [None] * self._num_uncompacted_rows
  78. for column_name in self._columns:
  79. value = item.get(column_name)
  80. self._columns[column_name].append(value)
  81. self._num_rows += 1
  82. self._num_uncompacted_rows += 1
  83. self._compact_if_needed()
  84. self._uncompacted_size.add(item)
  85. def add_block(self, block: Any) -> None:
  86. if not isinstance(block, self._block_type):
  87. raise TypeError(
  88. f"Got a block of type {type(block)}, expected {self._block_type}."
  89. "If you are mapping a function, ensure it returns an "
  90. "object with the expected type. Block:\n"
  91. f"{block}"
  92. )
  93. accessor = BlockAccessor.for_block(block)
  94. self._tables.append(block)
  95. self._num_rows += accessor.num_rows()
  96. @staticmethod
  97. def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
  98. raise NotImplementedError
  99. @staticmethod
  100. def _combine_tables(tables: List[Block]) -> Block:
  101. raise NotImplementedError
  102. @staticmethod
  103. def _empty_table() -> Any:
  104. raise NotImplementedError
  105. @staticmethod
  106. def _concat_would_copy() -> bool:
  107. raise NotImplementedError
  108. def will_build_yield_copy(self) -> bool:
  109. if self._columns:
  110. # Building a table from a dict of list columns always creates a copy.
  111. return True
  112. return self._concat_would_copy() and len(self._tables) > 1
  113. def build(self) -> Block:
  114. if self._columns:
  115. tables = [self._table_from_pydict(self._columns)]
  116. else:
  117. tables = []
  118. tables.extend(self._tables)
  119. if len(tables) == 0:
  120. return self._empty_table()
  121. else:
  122. return self._combine_tables(tables)
  123. def num_rows(self) -> int:
  124. return self._num_rows
  125. def num_blocks(self) -> int:
  126. return len(self._tables)
  127. def get_estimated_memory_usage(self) -> int:
  128. if self._num_rows == 0:
  129. return 0
  130. for table in self._tables[self._tables_size_cursor :]:
  131. self._tables_size_bytes += BlockAccessor.for_block(table).size_bytes()
  132. self._tables_size_cursor = len(self._tables)
  133. return self._tables_size_bytes + self._uncompacted_size.size_bytes()
  134. def _compact_if_needed(self) -> None:
  135. assert self._columns
  136. if self._uncompacted_size.size_bytes() < MAX_UNCOMPACTED_SIZE_BYTES:
  137. return
  138. block = self._table_from_pydict(self._columns)
  139. self.add_block(block)
  140. self._uncompacted_size = SizeEstimator()
  141. self._columns.clear()
  142. self._num_compactions += 1
  143. self._num_uncompacted_rows = 0
  144. class TableBlockAccessor(BlockAccessor):
  145. def __init__(self, table: Any):
  146. self._table = table
  147. @staticmethod
  148. def _munge_conflict(name, count):
  149. return f"{name}_{count + 1}"
  150. def to_default(self) -> Block:
  151. # Always promote Arrow blocks to pandas for consistency, since
  152. # we lazily convert pandas->Arrow internally for efficiency.
  153. default = self.to_pandas()
  154. return default
  155. def column_names(self) -> List[str]:
  156. raise NotImplementedError
  157. def fill_column(self, name: str, value: Any) -> Block:
  158. raise NotImplementedError
  159. def to_block(self) -> Block:
  160. return self._table
  161. def _zip(self, acc: BlockAccessor) -> "Block":
  162. raise NotImplementedError
  163. def zip(self, other: "Block") -> "Block":
  164. acc = BlockAccessor.for_block(other)
  165. if not isinstance(acc, type(self)):
  166. if isinstance(self, TableBlockAccessor) and isinstance(
  167. acc, TableBlockAccessor
  168. ):
  169. # If block types are different, but still both of TableBlock type, try
  170. # converting both to default block type before zipping.
  171. self_norm, other_norm = TableBlockAccessor.normalize_block_types(
  172. [self._table, other],
  173. )
  174. return BlockAccessor.for_block(self_norm).zip(other_norm)
  175. else:
  176. raise ValueError(
  177. "Cannot zip {} with block of type {}".format(
  178. type(self), type(other)
  179. )
  180. )
  181. if acc.num_rows() != self.num_rows():
  182. raise ValueError(
  183. "Cannot zip self (length {}) with block of length {}".format(
  184. self.num_rows(), acc.num_rows()
  185. )
  186. )
  187. return self._zip(acc)
  188. @staticmethod
  189. def _empty_table() -> Any:
  190. raise NotImplementedError
  191. def _sample(self, n_samples: int, sort_key: "SortKey") -> Any:
  192. raise NotImplementedError
  193. def sample(self, n_samples: int, sort_key: "SortKey") -> Any:
  194. if sort_key is None or callable(sort_key):
  195. raise NotImplementedError(
  196. f"Table sort key must be a column name, was: {sort_key}"
  197. )
  198. if self.num_rows() == 0:
  199. # If the pyarrow table is empty we may not have schema
  200. # so calling table.select() will raise an error.
  201. return self._empty_table()
  202. k = min(n_samples, self.num_rows())
  203. return self._sample(k, sort_key)
  204. def count(self, on: str, ignore_nulls: bool = False) -> Optional[U]:
  205. accessor = BlockColumnAccessor.for_column(self._table[on])
  206. return accessor.count(ignore_nulls=ignore_nulls)
  207. def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
  208. self._validate_column(on)
  209. accessor = BlockColumnAccessor.for_column(self._table[on])
  210. return accessor.sum(ignore_nulls=ignore_nulls)
  211. def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
  212. self._validate_column(on)
  213. accessor = BlockColumnAccessor.for_column(self._table[on])
  214. return accessor.min(ignore_nulls=ignore_nulls)
  215. def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
  216. self._validate_column(on)
  217. accessor = BlockColumnAccessor.for_column(self._table[on])
  218. return accessor.max(ignore_nulls=ignore_nulls)
  219. def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
  220. self._validate_column(on)
  221. accessor = BlockColumnAccessor.for_column(self._table[on])
  222. return accessor.mean(ignore_nulls=ignore_nulls)
  223. def sum_of_squared_diffs_from_mean(
  224. self,
  225. on: str,
  226. ignore_nulls: bool,
  227. mean: Optional[U] = None,
  228. ) -> Optional[U]:
  229. self._validate_column(on)
  230. accessor = BlockColumnAccessor.for_column(self._table[on])
  231. return accessor.sum_of_squared_diffs_from_mean(ignore_nulls=ignore_nulls)
  232. def _validate_column(self, col: str):
  233. if col is None:
  234. raise ValueError(f"Provided `on` value has to be non-null (got '{col}')")
  235. elif col not in self.column_names():
  236. raise ValueError(
  237. f"Referencing column '{col}' not present in the schema: {self.schema()}"
  238. )
  239. def _aggregate(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
  240. """Applies provided aggregations to groups of rows with the same key.
  241. This assumes the block is already sorted by key in ascending order.
  242. Args:
  243. sort_key: A column name or list of column names.
  244. If this is ``None``, place all rows in a single group.
  245. aggs: The aggregations to do.
  246. Returns:
  247. A sorted block of [k, v_1, ..., v_n] columns where k is the groupby
  248. key and v_i is the partially combined accumulator for the ith given
  249. aggregation.
  250. If key is None then the k column is omitted.
  251. """
  252. keys: List[str] = sort_key.get_columns()
  253. def iter_groups() -> Iterator[Tuple[Sequence[KeyType], Block]]:
  254. """Creates an iterator over zero-copy group views."""
  255. if not keys:
  256. # Global aggregation consists of a single "group", so we short-circuit.
  257. yield tuple(), self.to_block()
  258. return
  259. start = end = 0
  260. iter = self.iter_rows(public_row_format=False)
  261. next_row = None
  262. while True:
  263. try:
  264. if next_row is None:
  265. next_row = next(iter)
  266. next_keys = next_row[keys]
  267. while keys_equal(next_row[keys], next_keys):
  268. end += 1
  269. try:
  270. next_row = next(iter)
  271. except StopIteration:
  272. next_row = None
  273. break
  274. yield next_keys, self.slice(start, end)
  275. start = end
  276. except StopIteration:
  277. break
  278. builder = self.builder()
  279. for group_keys, group_view in iter_groups():
  280. # Aggregate.
  281. init_vals = group_keys
  282. if len(group_keys) == 1:
  283. init_vals = group_keys[0]
  284. accumulators = [agg.init(init_vals) for agg in aggs]
  285. for i in range(len(aggs)):
  286. accessor = BlockAccessor.for_block(group_view)
  287. # Skip empty blocks
  288. if accessor.num_rows() > 0:
  289. accumulators[i] = aggs[i].accumulate_block(
  290. accumulators[i], group_view
  291. )
  292. # Build the row.
  293. row = {}
  294. if keys:
  295. for k, gk in zip(keys, group_keys):
  296. row[k] = gk
  297. count = collections.defaultdict(int)
  298. for agg, accumulator in zip(aggs, accumulators):
  299. name = agg.name
  300. # Check for conflicts with existing aggregation name.
  301. if count[name] > 0:
  302. name = self._munge_conflict(name, count[name])
  303. count[name] += 1
  304. row[name] = accumulator
  305. builder.add(row)
  306. return builder.build()
  307. @classmethod
  308. def _combine_aggregated_blocks(
  309. cls,
  310. blocks: List[Block],
  311. sort_key: "SortKey",
  312. aggs: Tuple["AggregateFn"],
  313. finalize: bool = True,
  314. ) -> Tuple[Block, "BlockMetadataWithSchema"]:
  315. """Combine previously aggregated blocks.
  316. This assumes blocks are already sorted by key in ascending order,
  317. so we can do merge sort to get all the rows with the same key.
  318. Args:
  319. blocks: A list of partially combined and sorted blocks.
  320. sort_key: The column name of key or None for global aggregation.
  321. aggs: The aggregations to do.
  322. finalize: Whether to finalize the aggregation. This is used as an
  323. optimization for cases where we repeatedly combine partially
  324. aggregated groups.
  325. Returns:
  326. A block of [k, v_1, ..., v_n] columns and its metadata where k is
  327. the groupby key and v_i is the corresponding aggregation result for
  328. the ith given aggregation.
  329. If key is None then the k column is omitted.
  330. """
  331. # Handle blocks of different types.
  332. blocks = TableBlockAccessor.normalize_block_types(blocks)
  333. stats = BlockExecStats.builder()
  334. keys = sort_key.get_columns()
  335. def _key_fn(r):
  336. if keys:
  337. return tuple(r[keys])
  338. else:
  339. return (0,)
  340. # Replace `None`s and `np.nan` with NULL_SENTINEL to make sure
  341. # we can order the elements (both of these are incomparable)
  342. def safe_key_fn(r):
  343. values = _key_fn(r)
  344. return tuple(
  345. [NULL_SENTINEL if v is None or is_nan(v) else v for v in values]
  346. )
  347. iter = heapq.merge(
  348. *[
  349. BlockAccessor.for_block(block).iter_rows(public_row_format=False)
  350. for block in blocks
  351. ],
  352. key=safe_key_fn,
  353. )
  354. next_row = None
  355. builder = BlockAccessor.for_block(blocks[0]).builder()
  356. while True:
  357. try:
  358. if next_row is None:
  359. next_row = next(iter)
  360. next_keys = _key_fn(next_row)
  361. next_key_columns = keys
  362. def gen():
  363. nonlocal iter
  364. nonlocal next_row
  365. while keys_equal(_key_fn(next_row), next_keys):
  366. yield next_row
  367. try:
  368. next_row = next(iter)
  369. except StopIteration:
  370. next_row = None
  371. break
  372. # Merge.
  373. first = True
  374. accumulators = [None] * len(aggs)
  375. resolved_agg_names = [None] * len(aggs)
  376. for r in gen():
  377. if first:
  378. count = collections.defaultdict(int)
  379. for i in range(len(aggs)):
  380. name = aggs[i].name
  381. # Check for conflicts with existing aggregation
  382. # name.
  383. if count[name] > 0:
  384. name = TableBlockAccessor._munge_conflict(
  385. name, count[name]
  386. )
  387. count[name] += 1
  388. resolved_agg_names[i] = name
  389. accumulators[i] = r[name]
  390. first = False
  391. else:
  392. for i in range(len(aggs)):
  393. accumulators[i] = aggs[i].merge(
  394. accumulators[i], r[resolved_agg_names[i]]
  395. )
  396. # Build the row.
  397. row = {}
  398. if keys:
  399. for col_name, next_key in zip(next_key_columns, next_keys):
  400. row[col_name] = next_key
  401. for agg, agg_name, accumulator in zip(
  402. aggs, resolved_agg_names, accumulators
  403. ):
  404. if finalize:
  405. row[agg_name] = agg.finalize(accumulator)
  406. else:
  407. row[agg_name] = accumulator
  408. builder.add(row)
  409. except StopIteration:
  410. break
  411. ret = builder.build()
  412. return ret, BlockMetadataWithSchema.from_block(ret, stats=stats.build())
  413. def _find_partitions_sorted(
  414. self,
  415. boundaries: List[Tuple[Any]],
  416. sort_key: "SortKey",
  417. ):
  418. partitions = []
  419. # For each boundary value, count the number of items that are less
  420. # than it. Since the block is sorted, these counts partition the items
  421. # such that boundaries[i] <= x < boundaries[i + 1] for each x in
  422. # partition[i]. If `descending` is true, `boundaries` would also be
  423. # in descending order and we only need to count the number of items
  424. # *greater than* the boundary value instead.
  425. bounds = [
  426. find_partition_index(self._table, boundary, sort_key)
  427. for boundary in boundaries
  428. ]
  429. last_idx = 0
  430. for idx in bounds:
  431. partitions.append(self._table[last_idx:idx])
  432. last_idx = idx
  433. partitions.append(self._table[last_idx:])
  434. return partitions
  435. @classmethod
  436. def normalize_block_types(
  437. cls,
  438. blocks: List[Block],
  439. target_block_type: Optional[BlockType] = None,
  440. ) -> List[Block]:
  441. """Normalize input blocks to the specified `normalize_type`. If the blocks
  442. are already all of the same type, returns original blocks.
  443. Args:
  444. blocks: A list of TableBlocks to be normalized.
  445. target_block_type: The type to normalize the blocks to. If None,
  446. Ray Data chooses a type to minimize the amount of data conversions.
  447. Returns:
  448. A list of blocks of the same type.
  449. """
  450. seen_types: Dict[BlockType, int] = collections.defaultdict(int)
  451. for block in blocks:
  452. block_accessor = BlockAccessor.for_block(block)
  453. if not isinstance(block_accessor, TableBlockAccessor):
  454. raise ValueError(
  455. "Block type normalization is only supported for TableBlock, "
  456. f"but received block of type: {type(block)}."
  457. )
  458. seen_types[block_accessor.block_type()] += 1
  459. # If there's just 1 block-type and it's matching target-type, short-circuit
  460. if len(seen_types) == 1 and (
  461. target_block_type is None or [target_block_type] == list(seen_types.keys())
  462. ):
  463. return blocks
  464. # Pick the most prevalent block-type
  465. if target_block_type is None:
  466. _, target_block_type = sorted(
  467. seen_types.items(),
  468. key=lambda x: x[1],
  469. reverse=True,
  470. )[0]
  471. results = [
  472. cls.try_convert_block_type(block, target_block_type) for block in blocks
  473. ]
  474. if any(not isinstance(block, type(results[0])) for block in results):
  475. raise ValueError(
  476. "Expected all blocks to be of the same type after normalization, but "
  477. f"got different types: {[type(b) for b in results]}. "
  478. "Try using blocks of the same type to avoid the issue "
  479. "with block normalization."
  480. )
  481. return results
  482. @classmethod
  483. def try_convert_block_type(cls, block: Block, block_type: BlockType):
  484. if block_type == BlockType.ARROW:
  485. return BlockAccessor.for_block(block).to_arrow()
  486. elif block_type == BlockType.PANDAS:
  487. return BlockAccessor.for_block(block).to_pandas()
  488. else:
  489. return BlockAccessor.for_block(block).to_default()
  490. def hstack(self, other_block: Block) -> Block:
  491. """Combine this table with another table horizontally (column-wise).
  492. This will append the columns.
  493. Args:
  494. other_block: The table to hstack side-by-side with.
  495. Returns:
  496. A new table with columns from both tables combined.
  497. """
  498. raise NotImplementedError