arrow_block.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. import logging
  2. import random
  3. from typing import (
  4. TYPE_CHECKING,
  5. Any,
  6. Callable,
  7. Dict,
  8. Iterator,
  9. List,
  10. Mapping,
  11. Optional,
  12. Tuple,
  13. TypeVar,
  14. Union,
  15. )
  16. import numpy as np
  17. from packaging.version import parse as parse_version
  18. from ray._private.ray_constants import env_integer
  19. from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
  20. from ray.data._internal.arrow_ops.transform_pyarrow import shuffle
  21. from ray.data._internal.row import row_repr, row_repr_pretty, row_str
  22. from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
  23. from ray.data._internal.tensor_extensions.arrow import (
  24. convert_to_pyarrow_array,
  25. pyarrow_table_from_pydict,
  26. )
  27. from ray.data._internal.utils.arrow_utils import get_pyarrow_version
  28. from ray.data.block import (
  29. Block,
  30. BlockAccessor,
  31. BlockColumn,
  32. BlockColumnAccessor,
  33. BlockExecStats,
  34. BlockMetadataWithSchema,
  35. BlockType,
  36. U,
  37. )
  38. from ray.data.context import DEFAULT_TARGET_MAX_BLOCK_SIZE, DataContext
  39. from ray.data.expressions import Expr
  40. try:
  41. import pyarrow
  42. except ImportError:
  43. pyarrow = None
  44. if TYPE_CHECKING:
  45. import pandas
  46. from ray.data._internal.planner.exchange.sort_task_spec import SortKey
# Generic type variable; used below to annotate sort-boundary values.
T = TypeVar("T")

logger = logging.getLogger(__name__)

# pyarrow version threshold consulted by ArrowBlockColumnAccessor.to_numpy()
# before forwarding ``zero_copy_only`` to ``ChunkedArray.to_numpy()``.
_MIN_PYARROW_VERSION_TO_NUMPY_ZERO_COPY_ONLY = parse_version("13.0.0")

# Name of the placeholder column that ArrowBlockAccessor.select() adds when it
# is asked for an empty projection.
_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"

# Set the max chunk size in bytes for Arrow to Batches conversion in
# ArrowBlockAccessor.iter_rows(). Default to 4MB, to optimize for image
# datasets in parquet format.
# The value can be overridden through the RAY_DATA_ARROW_MAX_CHUNK_SIZE_BYTES
# environment variable.
ARROW_MAX_CHUNK_SIZE_BYTES = env_integer(
    "RAY_DATA_ARROW_MAX_CHUNK_SIZE_BYTES",
    int(DEFAULT_TARGET_MAX_BLOCK_SIZE / 32),
)
  58. # We offload some transformations to polars for performance.
  59. def get_sort_transform(context: DataContext) -> Callable:
  60. if context.use_polars or context.use_polars_sort:
  61. return transform_polars.sort
  62. else:
  63. return transform_pyarrow.sort
  64. def get_concat_and_sort_transform(context: DataContext) -> Callable:
  65. if context.use_polars or context.use_polars_sort:
  66. return transform_polars.concat_and_sort
  67. else:
  68. return transform_pyarrow.concat_and_sort
  69. class ArrowRow(Mapping):
  70. """
  71. Row of a tabular Dataset backed by a Arrow Table block.
  72. """
  73. def __init__(self, row: Any):
  74. self._row = row
  75. def __getitem__(self, key: Union[str, List[str]]) -> Any:
  76. from ray.data.extensions import get_arrow_extension_tensor_types
  77. tensor_arrow_extension_types = get_arrow_extension_tensor_types()
  78. def get_item(keys: List[str]) -> Any:
  79. schema = self._row.schema
  80. if isinstance(schema.field(keys[0]).type, tensor_arrow_extension_types):
  81. # Build a tensor row.
  82. return tuple(
  83. [
  84. ArrowBlockAccessor._build_tensor_row(
  85. self._row, col_name=key, row_idx=0
  86. )
  87. for key in keys
  88. ]
  89. )
  90. table = self._row.select(keys)
  91. if len(table) == 0:
  92. return None
  93. items = [col[0] for col in table.columns]
  94. try:
  95. # Try to interpret this as a pyarrow.Scalar value.
  96. return tuple([item.as_py() for item in items])
  97. except AttributeError:
  98. # Assume that this row is an element of an extension array, and
  99. # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
  100. return items
  101. is_single_item = isinstance(key, str)
  102. keys = [key] if is_single_item else key
  103. items = get_item(keys)
  104. if items is None:
  105. return None
  106. elif is_single_item:
  107. return items[0]
  108. else:
  109. return items
  110. def __iter__(self) -> Iterator:
  111. for k in self._row.column_names:
  112. yield k
  113. def __len__(self):
  114. return self._row.num_columns
  115. def as_pydict(self) -> Dict[str, Any]:
  116. return dict(self.items())
  117. def __str__(self):
  118. return row_str(self)
  119. def __repr__(self):
  120. return row_repr(self)
  121. def _repr_pretty_(self, p, cycle):
  122. return row_repr_pretty(self, p, cycle)
  123. class ArrowBlockBuilder(TableBlockBuilder):
  124. def __init__(self):
  125. if pyarrow is None:
  126. raise ImportError("Run `pip install pyarrow` for Arrow support")
  127. super().__init__((pyarrow.Table, bytes))
  128. @staticmethod
  129. def _table_from_pydict(columns: Dict[str, List[Any]]) -> Block:
  130. return pyarrow_table_from_pydict(
  131. {
  132. column_name: convert_to_pyarrow_array(column_values, column_name)
  133. for column_name, column_values in columns.items()
  134. }
  135. )
  136. @staticmethod
  137. def _combine_tables(tables: List[Block]) -> Block:
  138. if len(tables) > 1:
  139. return transform_pyarrow.concat(tables, promote_types=True)
  140. else:
  141. return tables[0]
  142. @staticmethod
  143. def _concat_would_copy() -> bool:
  144. return False
  145. @staticmethod
  146. def _empty_table() -> "pyarrow.Table":
  147. return pyarrow_table_from_pydict({})
  148. def block_type(self) -> BlockType:
  149. return BlockType.ARROW
  150. def _get_max_chunk_size(
  151. table: "pyarrow.Table", max_chunk_size_bytes: int
  152. ) -> Optional[int]:
  153. """
  154. Calculate the max chunk size in rows for Arrow to Batches conversion in
  155. ArrowBlockAccessor.iter_rows().
  156. Args:
  157. table: The pyarrow table to calculate the max chunk size for.
  158. max_chunk_size_bytes: The max chunk size in bytes.
  159. Returns:
  160. The max chunk size in rows, or None if the table is empty.
  161. """
  162. if table.nbytes == 0:
  163. return None
  164. else:
  165. avg_row_size = table.nbytes / table.num_rows
  166. return max(1, int(max_chunk_size_bytes / avg_row_size))
  167. class ArrowBlockAccessor(TableBlockAccessor):
  168. ROW_TYPE = ArrowRow
  169. def __init__(self, table: "pyarrow.Table"):
  170. if pyarrow is None:
  171. raise ImportError("Run `pip install pyarrow` for Arrow support")
  172. super().__init__(table)
  173. self._max_chunk_size: Optional[int] = None
  174. def _get_row(self, index: int) -> ArrowRow:
  175. base_row = self.slice(index, index + 1, copy=False)
  176. return ArrowRow(base_row)
  177. def column_names(self) -> List[str]:
  178. return self._table.column_names
  179. def fill_column(self, name: str, value: Any) -> Block:
  180. import pyarrow.compute as pc
  181. # Check if value is array-like - if so, use upsert_column logic
  182. if isinstance(value, (pyarrow.Array, pyarrow.ChunkedArray)):
  183. return self.upsert_column(name, value)
  184. else:
  185. # Scalar value - use original fill_column logic
  186. if isinstance(value, pyarrow.Scalar):
  187. type = value.type
  188. else:
  189. type = pyarrow.infer_type([value])
  190. array = pyarrow.nulls(len(self._table), type=type)
  191. array = pc.fill_null(array, value)
  192. return self.upsert_column(name, array)
  193. @classmethod
  194. def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":
  195. reader = pyarrow.ipc.open_stream(data)
  196. return cls(reader.read_all())
  197. @staticmethod
  198. def _build_tensor_row(row: ArrowRow, row_idx: int, col_name: str) -> np.ndarray:
  199. element = row[col_name][row_idx]
  200. arr = element.as_py()
  201. assert isinstance(arr, np.ndarray), type(arr)
  202. return arr
  203. def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
  204. view = self._table.slice(start, end - start)
  205. if copy:
  206. view = transform_pyarrow.combine_chunks(view, copy)
  207. return view
  208. def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
  209. return shuffle(self._table, random_seed)
  210. def schema(self) -> "pyarrow.lib.Schema":
  211. return self._table.schema
  212. def to_pandas(self) -> "pandas.DataFrame":
  213. from ray.data.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays
  214. # We specify ignore_metadata=True because pyarrow will use the metadata
  215. # to build the Table. This is handled incorrectly for older pyarrow versions
  216. ctx = DataContext.get_current()
  217. df = self._table.to_pandas(ignore_metadata=ctx.pandas_block_ignore_metadata)
  218. if ctx.enable_tensor_extension_casting:
  219. df = _cast_tensor_columns_to_ndarrays(df)
  220. return df
  221. def to_numpy(
  222. self, columns: Optional[Union[str, List[str]]] = None
  223. ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
  224. if columns is None:
  225. columns = self._table.column_names
  226. should_be_single_ndarray = False
  227. elif isinstance(columns, list):
  228. should_be_single_ndarray = False
  229. else:
  230. columns = [columns]
  231. should_be_single_ndarray = True
  232. column_names_set = set(self._table.column_names)
  233. for column in columns:
  234. if column not in column_names_set:
  235. raise ValueError(
  236. f"Cannot find column {column}, available columns: "
  237. f"{column_names_set}"
  238. )
  239. column_values_ndarrays = []
  240. for col_name in columns:
  241. col = self._table[col_name]
  242. # Combine columnar values arrays to make these contiguous
  243. # (making them compatible with numpy format)
  244. combined_array = transform_pyarrow.combine_chunked_array(col)
  245. column_values_ndarrays.append(
  246. transform_pyarrow.to_numpy(combined_array, zero_copy_only=False)
  247. )
  248. if should_be_single_ndarray:
  249. assert len(columns) == 1
  250. return column_values_ndarrays[0]
  251. else:
  252. return dict(zip(columns, column_values_ndarrays))
  253. def to_arrow(self) -> "pyarrow.Table":
  254. return self._table
  255. def num_rows(self) -> int:
  256. # Arrow may represent an empty table via an N > 0 row, 0-column table, e.g. when
  257. # slicing an empty table, so we return 0 if num_columns == 0.
  258. return self._table.num_rows if self._table.num_columns > 0 else 0
  259. def size_bytes(self) -> int:
  260. return self._table.nbytes
  261. def _zip(self, acc: BlockAccessor) -> "Block":
  262. r = self.to_arrow()
  263. s = acc.to_arrow()
  264. for col_name in s.column_names:
  265. col = s.column(col_name)
  266. # Ensure the column names are unique after zip.
  267. if col_name in r.column_names:
  268. i = 1
  269. new_name = col_name
  270. while new_name in r.column_names:
  271. new_name = "{}_{}".format(col_name, i)
  272. i += 1
  273. col_name = new_name
  274. r = r.append_column(col_name, col)
  275. return r
  276. def upsert_column(
  277. self, column_name: str, column_data: BlockColumn
  278. ) -> "pyarrow.Table":
  279. assert isinstance(
  280. column_data, (pyarrow.Array, pyarrow.ChunkedArray)
  281. ), f"Expected either a pyarrow.Array or pyarrow.ChunkedArray, got: {type(column_data)}"
  282. column_idx = self._table.schema.get_field_index(column_name)
  283. if column_idx == -1:
  284. return self._table.append_column(column_name, column_data)
  285. else:
  286. return self._table.set_column(column_idx, column_name, column_data)
  287. @staticmethod
  288. def builder() -> ArrowBlockBuilder:
  289. return ArrowBlockBuilder()
  290. @staticmethod
  291. def _empty_table() -> "pyarrow.Table":
  292. return ArrowBlockBuilder._empty_table()
  293. def take(
  294. self,
  295. indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
  296. ) -> "pyarrow.Table":
  297. """Select rows from the underlying table.
  298. This method is an alternative to pyarrow.Table.take(), which breaks for
  299. extension arrays.
  300. """
  301. return transform_pyarrow.take_table(self._table, indices)
  302. def drop(self, columns: List[str]) -> Block:
  303. return self._table.drop(columns)
  304. def select(self, columns: List[str]) -> "pyarrow.Table":
  305. if not all(isinstance(col, str) for col in columns):
  306. raise ValueError(
  307. "Columns must be a list of column name strings when aggregating on "
  308. f"Arrow blocks, but got: {columns}."
  309. )
  310. if len(columns) == 0:
  311. # Applicable for count which does an empty projection.
  312. # Pyarrow returns a table with 0 columns and num_rows rows.
  313. return self.fill_column(_BATCH_SIZE_PRESERVING_STUB_COL_NAME, None)
  314. return self._table.select(columns)
  315. def rename_columns(self, columns_rename: Dict[str, str]) -> "pyarrow.Table":
  316. return self._table.rename_columns(columns_rename)
  317. def hstack(self, other_block: "pyarrow.Table") -> "pyarrow.Table":
  318. result_table = self._table
  319. for name, column in zip(other_block.column_names, other_block.columns):
  320. result_table = result_table.append_column(name, column)
  321. return result_table
  322. def _sample(self, n_samples: int, sort_key: "SortKey") -> "pyarrow.Table":
  323. indices = random.sample(range(self._table.num_rows), n_samples)
  324. table = self._table.select(sort_key.get_columns())
  325. return transform_pyarrow.take_table(table, indices)
  326. def sort(self, sort_key: "SortKey") -> Block:
  327. assert (
  328. sort_key.get_columns()
  329. ), f"Sorting columns couldn't be empty (got {sort_key.get_columns()})"
  330. if self._table.num_rows == 0:
  331. # If the pyarrow table is empty we may not have schema
  332. # so calling sort_indices() will raise an error.
  333. return self._empty_table()
  334. context = DataContext.get_current()
  335. sort = get_sort_transform(context)
  336. return sort(self._table, sort_key)
  337. def sort_and_partition(
  338. self, boundaries: List[T], sort_key: "SortKey"
  339. ) -> List["Block"]:
  340. table = self.sort(sort_key)
  341. if table.num_rows == 0:
  342. return [self._empty_table() for _ in range(len(boundaries) + 1)]
  343. elif len(boundaries) == 0:
  344. return [table]
  345. return BlockAccessor.for_block(table)._find_partitions_sorted(
  346. boundaries, sort_key
  347. )
  348. @staticmethod
  349. def merge_sorted_blocks(
  350. blocks: List[Block], sort_key: "SortKey"
  351. ) -> Tuple[Block, BlockMetadataWithSchema]:
  352. stats = BlockExecStats.builder()
  353. blocks = [b for b in blocks if b.num_rows > 0]
  354. if len(blocks) == 0:
  355. ret = ArrowBlockAccessor._empty_table()
  356. else:
  357. # Handle blocks of different types.
  358. blocks = TableBlockAccessor.normalize_block_types(blocks, BlockType.ARROW)
  359. concat_and_sort = get_concat_and_sort_transform(DataContext.get_current())
  360. ret = concat_and_sort(blocks, sort_key, promote_types=True)
  361. return ret, BlockMetadataWithSchema.from_block(ret, stats=stats.build())
  362. def block_type(self) -> BlockType:
  363. return BlockType.ARROW
  364. def iter_rows(
  365. self, public_row_format: bool
  366. ) -> Iterator[Union[Mapping, np.ndarray]]:
  367. table = self._table
  368. if public_row_format:
  369. if self._max_chunk_size is None:
  370. # Calling _get_max_chunk_size in constructor makes it slow, so we
  371. # are calling it here only when needed.
  372. self._max_chunk_size = _get_max_chunk_size(
  373. table, ARROW_MAX_CHUNK_SIZE_BYTES
  374. )
  375. for batch in table.to_batches(max_chunksize=self._max_chunk_size):
  376. yield from batch.to_pylist()
  377. else:
  378. num_rows = self.num_rows()
  379. for i in range(num_rows):
  380. yield self._get_row(i)
  381. def filter(self, predicate_expr: "Expr") -> "pyarrow.Table":
  382. """Filter rows based on a predicate expression."""
  383. if self._table.num_rows == 0:
  384. return self._table
  385. from ray.data._internal.planner.plan_expression.expression_evaluator import (
  386. eval_expr,
  387. )
  388. # Evaluate the expression to get a boolean mask
  389. mask = eval_expr(predicate_expr, self._table)
  390. # Use PyArrow's built-in filter method
  391. return self._table.filter(mask)
  392. class ArrowBlockColumnAccessor(BlockColumnAccessor):
  393. def __init__(self, col: Union["pyarrow.Array", "pyarrow.ChunkedArray"]):
  394. super().__init__(col)
  395. def count(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
  396. import pyarrow.compute as pac
  397. res = pac.count(self._column, mode="only_valid" if ignore_nulls else "all")
  398. return res.as_py() if as_py else res
  399. def sum(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
  400. import pyarrow.compute as pac
  401. res = pac.sum(self._column, skip_nulls=ignore_nulls)
  402. return res.as_py() if as_py else res
  403. def min(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
  404. import pyarrow.compute as pac
  405. res = pac.min(self._column, skip_nulls=ignore_nulls)
  406. return res.as_py() if as_py else res
  407. def max(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
  408. import pyarrow.compute as pac
  409. res = pac.max(self._column, skip_nulls=ignore_nulls)
  410. return res.as_py() if as_py else res
  411. def mean(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
  412. import pyarrow.compute as pac
  413. res = pac.mean(self._column, skip_nulls=ignore_nulls)
  414. return res.as_py() if as_py else res
  415. def sum_of_squared_diffs_from_mean(
  416. self, ignore_nulls: bool, mean: Optional[U] = None, as_py: bool = True
  417. ) -> Optional[U]:
  418. import pyarrow.compute as pac
  419. # Calculate mean if not provided
  420. if mean is None:
  421. mean = self.mean(ignore_nulls=ignore_nulls)
  422. if mean is None:
  423. return None
  424. res = pac.sum(
  425. pac.power(pac.subtract(self._column, mean), 2), skip_nulls=ignore_nulls
  426. )
  427. return res.as_py() if as_py else res
  428. def quantile(
  429. self, *, q: float, ignore_nulls: bool, as_py: bool = True
  430. ) -> Optional[U]:
  431. import pyarrow.compute as pac
  432. array = pac.quantile(self._column, q=q, skip_nulls=ignore_nulls)
  433. # NOTE: That quantile method still returns an array
  434. res = array[0]
  435. return res.as_py() if as_py else res
  436. def unique(self) -> BlockColumn:
  437. import pyarrow.compute as pac
  438. if self.is_composed_of_lists():
  439. # NOTE: Arrow doesn't provide unique kernels for `ListArray`s and
  440. # such, so we rely on Polars to encode and compute unique
  441. # values instead
  442. import polars
  443. return polars.from_arrow(self._column).unique().to_arrow()
  444. return pac.unique(self._column)
  445. def value_counts(self) -> Optional[Dict[str, List]]:
  446. import pyarrow.compute as pac
  447. value_counts: pyarrow.StructArray = pac.value_counts(self._column)
  448. if len(value_counts) == 0:
  449. return None
  450. return {
  451. "values": value_counts.field("values").to_pylist(),
  452. "counts": value_counts.field("counts").to_pylist(),
  453. }
  454. def hash(self) -> BlockColumn:
  455. import polars as pl
  456. df = pl.DataFrame({"col": self._column})
  457. hashes = df.hash_rows().cast(pl.Int64, wrap_numerical=True)
  458. return hashes.to_arrow()
  459. def flatten(self) -> BlockColumn:
  460. import pyarrow.compute as pac
  461. return pac.list_flatten(self._column)
  462. def dropna(self) -> BlockColumn:
  463. import pyarrow.compute as pac
  464. return pac.drop_null(self._column)
  465. def is_composed_of_lists(self) -> bool:
  466. types = (pyarrow.lib.ListType, pyarrow.lib.LargeListType)
  467. return isinstance(self._column.type, types)
  468. def to_pylist(self) -> List[Any]:
  469. return self._column.to_pylist()
  470. def to_numpy(self, zero_copy_only: bool = False) -> np.ndarray:
  471. if get_pyarrow_version() < _MIN_PYARROW_VERSION_TO_NUMPY_ZERO_COPY_ONLY:
  472. if isinstance(
  473. self._column, pyarrow.ChunkedArray
  474. ): # NOTE: ChunkedArray in Pyarrow < 13.0.0 does not support ``zero_copy_only``
  475. return self._column.to_numpy()
  476. else:
  477. return self._column.to_numpy(zero_copy_only=zero_copy_only)
  478. return self._column.to_numpy(zero_copy_only=zero_copy_only)
  479. def _as_arrow_compatible(self) -> Union[List[Any], "pyarrow.Array"]:
  480. return self._column