stats.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. import logging
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
  4. import pandas as pd
  5. import pyarrow as pa
  6. from ray.data._internal.tensor_extensions.arrow import convert_to_pyarrow_array
  7. from ray.data.aggregate import (
  8. AggregateFnV2,
  9. ApproximateQuantile,
  10. ApproximateTopK,
  11. Count,
  12. Max,
  13. Mean,
  14. Min,
  15. MissingValuePercentage,
  16. Std,
  17. ZeroPercentage,
  18. )
  19. from ray.util.annotations import PublicAPI
  20. if TYPE_CHECKING:
  21. from ray.data.dataset import Schema
  22. from ray.data.datatype import DataType, TypeCategory
  23. logger = logging.getLogger(__name__)
  24. @PublicAPI(stability="alpha")
  25. @dataclass
  26. class DatasetSummary:
  27. """Wrapper for dataset summary statistics.
  28. Provides methods to access computed statistics.
  29. Attributes:
  30. dataset_schema: PyArrow schema of the original dataset
  31. """
  32. STATISTIC_COLUMN = "statistic"
  33. # PyArrow requires tables whereby each column's value conforms to the column's dtype as defined by the schema.
  34. # However, aggregation results might produce statistics with types different from
  35. # the original column (e.g., 'count' is int64 even for string columns).
  36. # To handle this, we split statistics into two tables:
  37. # 1. _stats_matching_column_dtype: Statistics that share the same type as the
  38. # original column (e.g., min/max for numerical columns). These preserve
  39. # the original column's dtype.
  40. # 2. _stats_mismatching_column_dtype: Statistics with different types (e.g., count,
  41. # missing_pct). These use inferred types (e.g., float64 for count).
  42. _stats_matching_column_dtype: pa.Table
  43. _stats_mismatching_column_dtype: pa.Table
  44. dataset_schema: pa.Schema
  45. columns: list[str]
  46. def _safe_convert_table(self, table: pa.Table):
  47. """Safely convert a PyArrow table to pandas, handling problematic extension types.
  48. Args:
  49. table: PyArrow table to convert
  50. Returns:
  51. pandas DataFrame with converted data
  52. """
  53. from ray.data.block import BlockAccessor
  54. try:
  55. return BlockAccessor.for_block(table).to_pandas()
  56. except (TypeError, ValueError, pa.ArrowInvalid) as e:
  57. logger.warning(
  58. f"Direct conversion to pandas failed ({e}), "
  59. "attempting column-by-column conversion"
  60. )
  61. result_data = {}
  62. for col_name in table.schema.names:
  63. col = table.column(col_name)
  64. try:
  65. result_data[col_name] = col.to_pandas()
  66. except (TypeError, ValueError, pa.ArrowInvalid):
  67. # Cast problematic columns to null type
  68. null_col = pa.nulls(len(col), type=pa.null())
  69. result_data[col_name] = null_col.to_pandas()
  70. return pd.DataFrame(result_data)
  71. def _set_statistic_index(self, df: pd.DataFrame) -> pd.DataFrame:
  72. """Set the statistic column as index if it exists, else return empty DataFrame.
  73. Args:
  74. df: DataFrame to set index on
  75. Returns:
  76. DataFrame with statistic column as index, or empty DataFrame if column missing
  77. """
  78. if self.STATISTIC_COLUMN in df.columns:
  79. return df.set_index(self.STATISTIC_COLUMN)
  80. return pd.DataFrame()
  81. def to_pandas(self):
  82. """Convert summary to a single pandas DataFrame.
  83. Combines statistics from both schema-matching and schema-changing tables.
  84. Note: Some PyArrow extension types (like TensorExtensionType) may fail to convert
  85. to pandas when all values in a column are None. In such cases, this method
  86. attempts to convert column-by-column, casting problematic columns to null type.
  87. Returns:
  88. DataFrame with all statistics, where rows are unique statistics from both tables
  89. """
  90. df_matching = self._set_statistic_index(
  91. self._safe_convert_table(self._stats_matching_column_dtype)
  92. )
  93. df_changing = self._set_statistic_index(
  94. self._safe_convert_table(self._stats_mismatching_column_dtype)
  95. )
  96. # Handle case where both are empty
  97. if df_matching.empty and df_changing.empty:
  98. return pd.DataFrame(columns=[self.STATISTIC_COLUMN])
  99. # Combine tables: prefer schema_matching values, fill with schema_changing
  100. result = df_matching.combine_first(df_changing)
  101. return (
  102. result.reset_index()
  103. .sort_values(self.STATISTIC_COLUMN)
  104. .reset_index(drop=True)
  105. )
  106. def _extract_column_from_table(
  107. self, table: pa.Table, column: str
  108. ) -> Optional[dict]:
  109. """Extract a column from a PyArrow table if it exists.
  110. Args:
  111. table: PyArrow table to extract from
  112. column: Column name to extract
  113. Returns:
  114. DataFrame with 'statistic' and 'value' columns, or None if column doesn't exist
  115. """
  116. if column not in table.schema.names:
  117. return None
  118. df = self._safe_convert_table(table)[[self.STATISTIC_COLUMN, column]]
  119. return df.rename(columns={column: "value"})
  120. def get_column_stats(self, column: str):
  121. """Get all statistics for a specific column, merging from both tables.
  122. Args:
  123. column: Column name to get statistics for
  124. Returns:
  125. DataFrame with all statistics for the column
  126. """
  127. dfs = [
  128. df
  129. for table in [
  130. self._stats_matching_column_dtype,
  131. self._stats_mismatching_column_dtype,
  132. ]
  133. if (df := self._extract_column_from_table(table, column)) is not None
  134. ]
  135. if not dfs:
  136. raise ValueError(f"Column '{column}' not found in summary tables")
  137. # Concatenate and merge duplicate statistics (prefer non-null values)
  138. combined = pd.concat(dfs, ignore_index=True)
  139. # Group by statistic and take first non-null value for each group
  140. def first_non_null(series):
  141. non_null = series.dropna()
  142. return non_null.iloc[0] if len(non_null) > 0 else None
  143. result = (
  144. combined.groupby(self.STATISTIC_COLUMN, sort=False)["value"]
  145. .apply(first_non_null)
  146. .reset_index()
  147. .sort_values(self.STATISTIC_COLUMN)
  148. .reset_index(drop=True)
  149. )
  150. return result
  151. @dataclass
  152. class _DtypeAggregators:
  153. """Container for columns and their aggregators.
  154. Attributes:
  155. column_to_dtype: Mapping from column name to dtype string representation
  156. aggregators: List of all aggregators to apply
  157. """
  158. column_to_dtype: Dict[str, str]
  159. aggregators: List[AggregateFnV2]
  160. def _numerical_aggregators(column: str) -> List[AggregateFnV2]:
  161. """Generate default metrics for numerical columns.
  162. This function returns a list of aggregators that compute the following metrics:
  163. - count
  164. - mean
  165. - min
  166. - max
  167. - std
  168. - approximate_quantile (median)
  169. - missing_value_percentage
  170. - zero_percentage
  171. Args:
  172. column: The name of the numerical column to compute metrics for.
  173. Returns:
  174. A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
  175. """
  176. return [
  177. Count(on=column, ignore_nulls=False),
  178. Mean(on=column, ignore_nulls=True),
  179. Min(on=column, ignore_nulls=True),
  180. Max(on=column, ignore_nulls=True),
  181. Std(on=column, ignore_nulls=True, ddof=0),
  182. ApproximateQuantile(on=column, quantiles=[0.5]),
  183. MissingValuePercentage(on=column),
  184. ZeroPercentage(on=column, ignore_nulls=True),
  185. ]
  186. def _temporal_aggregators(column: str) -> List[AggregateFnV2]:
  187. """Generate default metrics for temporal columns.
  188. This function returns a list of aggregators that compute the following metrics:
  189. - count
  190. - min
  191. - max
  192. - missing_value_percentage
  193. Args:
  194. column: The name of the temporal column to compute metrics for.
  195. Returns:
  196. A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
  197. """
  198. return [
  199. Count(on=column, ignore_nulls=False),
  200. Min(on=column, ignore_nulls=True),
  201. Max(on=column, ignore_nulls=True),
  202. MissingValuePercentage(on=column),
  203. ]
  204. def _basic_aggregators(column: str) -> List[AggregateFnV2]:
  205. """Generate default metrics for all columns.
  206. This function returns a list of aggregators that compute the following metrics:
  207. - count
  208. - missing_value_percentage
  209. - approximate_top_k (top 10 most frequent values)
  210. Args:
  211. column: The name of the column to compute metrics for.
  212. Returns:
  213. A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
  214. """
  215. return [
  216. Count(on=column, ignore_nulls=False),
  217. MissingValuePercentage(on=column),
  218. ApproximateTopK(on=column, k=10),
  219. ]
  220. def _default_dtype_aggregators() -> Dict[
  221. Union["DataType", "TypeCategory"], Callable[[str], List[AggregateFnV2]]
  222. ]:
  223. """Get default mapping from Ray Data DataType to aggregator factory functions.
  224. This function returns factory functions that create aggregators for specific columns.
  225. Returns:
  226. Dict mapping DataType or TypeCategory to factory functions that take a column name
  227. and return a list of aggregators for that column.
  228. Examples:
  229. >>> from ray.data.datatype import DataType
  230. >>> from ray.data.stats import _default_dtype_aggregators
  231. >>> mapping = _default_dtype_aggregators()
  232. >>> factory = mapping.get(DataType.int32())
  233. >>> aggs = factory("my_column") # Creates aggregators for "my_column"
  234. """
  235. from ray.data.datatype import DataType, TypeCategory
  236. # Use pattern-matching types for cleaner mapping
  237. return {
  238. # Numerical types
  239. DataType.int8(): _numerical_aggregators,
  240. DataType.int16(): _numerical_aggregators,
  241. DataType.int32(): _numerical_aggregators,
  242. DataType.int64(): _numerical_aggregators,
  243. DataType.uint8(): _numerical_aggregators,
  244. DataType.uint16(): _numerical_aggregators,
  245. DataType.uint32(): _numerical_aggregators,
  246. DataType.uint64(): _numerical_aggregators,
  247. DataType.float32(): _numerical_aggregators,
  248. DataType.float64(): _numerical_aggregators,
  249. DataType.bool(): _numerical_aggregators,
  250. # String and binary types
  251. DataType.string(): _basic_aggregators,
  252. DataType.binary(): _basic_aggregators,
  253. # Temporal types - pattern matches all temporal types (timestamp, date, time, duration)
  254. TypeCategory.TEMPORAL: _temporal_aggregators,
  255. # Note: Complex types like lists, structs, maps use fallback logic
  256. # in _get_aggregators_for_dtype since they can't be easily enumerated
  257. }
  258. def _get_fallback_aggregators(column: str, dtype: "DataType") -> List[AggregateFnV2]:
  259. """Get aggregators using heuristic-based type detection.
  260. This is a fallback when no explicit mapping is found for the dtype.
  261. Args:
  262. column: Column name
  263. dtype: Ray Data DataType for the column
  264. Returns:
  265. List of aggregators suitable for the column type
  266. """
  267. try:
  268. # Check for null type first
  269. if dtype.is_arrow_type() and pa.types.is_null(dtype._physical_dtype):
  270. return [Count(on=column, ignore_nulls=False)]
  271. elif dtype.is_numerical_type():
  272. return _numerical_aggregators(column)
  273. elif dtype.is_temporal_type():
  274. return _temporal_aggregators(column)
  275. else:
  276. # Default for strings, binary, lists, nested types, etc.
  277. return _basic_aggregators(column)
  278. except Exception as e:
  279. logger.warning(
  280. f"Could not determine aggregators for column '{column}' with dtype {dtype}: {e}. "
  281. f"Using basic aggregators."
  282. )
  283. return _basic_aggregators(column)
  284. def _get_aggregators_for_dtype(
  285. column: str,
  286. dtype: "DataType",
  287. dtype_agg_mapping: Dict[
  288. Union["DataType", "TypeCategory"], Callable[[str], List[AggregateFnV2]]
  289. ],
  290. ) -> List[AggregateFnV2]:
  291. """Get aggregators for a specific column based on its DataType.
  292. Attempts to match the dtype against the provided mapping first, then
  293. falls back to heuristic-based selection if no match is found.
  294. Args:
  295. column: Column name
  296. dtype: Ray Data DataType for the column
  297. dtype_agg_mapping: Mapping from DataType to factory functions
  298. Returns:
  299. List of aggregators with the column name properly set
  300. """
  301. from ray.data.datatype import DataType, TypeCategory
  302. # Try to find a match in the mapping
  303. for mapping_key, factory in dtype_agg_mapping.items():
  304. if isinstance(mapping_key, DataType) and dtype == mapping_key:
  305. return factory(column)
  306. elif isinstance(mapping_key, (TypeCategory, str)) and dtype.is_of(mapping_key):
  307. return factory(column)
  308. # Fallback: Use heuristic-based selection
  309. return _get_fallback_aggregators(column, dtype)
  310. def _dtype_aggregators_for_dataset(
  311. schema: Optional["Schema"],
  312. columns: Optional[List[str]] = None,
  313. dtype_agg_mapping: Optional[
  314. Dict[Union["DataType", "TypeCategory"], Callable[[str], List[AggregateFnV2]]]
  315. ] = None,
  316. ) -> _DtypeAggregators:
  317. """Generate aggregators for columns in a dataset based on their DataTypes.
  318. Args:
  319. schema: A Ray Schema instance
  320. columns: List of columns to include. If None, all columns will be included.
  321. dtype_agg_mapping: Optional user-provided mapping from DataType to aggregator factories.
  322. Each value should be a callable that takes a column name and returns aggregators.
  323. This will be merged with the default mapping (user mapping takes precedence).
  324. Returns:
  325. _DtypeAggregators containing column-to-dtype mapping and aggregators
  326. Raises:
  327. ValueError: If schema is None or if specified columns don't exist in schema
  328. """
  329. from ray.data.datatype import DataType
  330. if not schema:
  331. raise ValueError("Dataset must have a schema to determine column types")
  332. if columns is None:
  333. columns = schema.names
  334. # Validate columns exist in schema
  335. missing_cols = set(columns) - set(schema.names)
  336. if missing_cols:
  337. raise ValueError(f"Columns {missing_cols} not found in dataset schema")
  338. # Build final mapping: default + user overrides
  339. defaults = _default_dtype_aggregators()
  340. if dtype_agg_mapping:
  341. # Put user overrides first so they are checked before default patterns
  342. final_mapping = dtype_agg_mapping.copy()
  343. for k, v in defaults.items():
  344. if k not in final_mapping:
  345. final_mapping[k] = v
  346. else:
  347. final_mapping = defaults
  348. # Generate aggregators for each column
  349. column_to_dtype = {}
  350. all_aggs = []
  351. name_to_type = dict(zip(schema.names, schema.types))
  352. for name in columns:
  353. pa_type = name_to_type[name]
  354. if pa_type is None or pa_type is object:
  355. logger.warning(f"Skipping field '{name}': type is None or unsupported")
  356. continue
  357. ray_dtype = DataType.from_arrow(pa_type)
  358. column_to_dtype[name] = str(ray_dtype)
  359. all_aggs.extend(_get_aggregators_for_dtype(name, ray_dtype, final_mapping))
  360. return _DtypeAggregators(
  361. column_to_dtype=column_to_dtype,
  362. aggregators=all_aggs,
  363. )
  364. def _format_stats(
  365. agg: AggregateFnV2, value: Any, agg_type: pa.DataType
  366. ) -> Dict[str, Tuple[Any, pa.DataType]]:
  367. """Format aggregation result into stat entries.
  368. Takes the raw aggregation result and formats it into one or more stat
  369. entries. For scalar results, returns a single entry. For list results,
  370. expands into multiple indexed entries.
  371. Args:
  372. agg: The aggregator instance
  373. value: The aggregation result value
  374. agg_type: PyArrow type of the aggregation result
  375. Returns:
  376. Dictionary mapping stat names to (value, type) tuples
  377. """
  378. from ray.data.datatype import DataType
  379. agg_name = agg.get_agg_name()
  380. # Handle list results: expand into separate indexed stats
  381. # If the value is None but the type is list, it means we got a null result
  382. # for a list-type aggregator (e.g., ignore_nulls=True and all nulls).
  383. is_list_type = (
  384. pa.types.is_list(agg_type) or DataType.from_arrow(agg_type).is_list_type()
  385. )
  386. if isinstance(value, list) or (value is None and is_list_type):
  387. scalar_type = (
  388. agg_type.value_type
  389. if DataType.from_arrow(agg_type).is_list_type()
  390. else agg_type
  391. )
  392. if value is None:
  393. # Can't expand None without knowing the size, return as-is
  394. pass
  395. else:
  396. labels = [str(idx) for idx in range(len(value))]
  397. return {
  398. f"{agg_name}[{label}]": (list_val, scalar_type)
  399. for label, list_val in zip(labels, value)
  400. }
  401. # Fallback to scalar result for non-list values or unexpandable Nones
  402. return {agg_name: (value, agg_type)}
  403. def _parse_summary_stats(
  404. agg_result: Dict[str, any],
  405. original_schema: pa.Schema,
  406. agg_schema: pa.Schema,
  407. aggregators: List[AggregateFnV2],
  408. ) -> tuple:
  409. """Parse aggregation results into schema-matching and schema-changing stats.
  410. Args:
  411. agg_result: Dictionary of aggregation results with keys like "count(col)"
  412. original_schema: Original dataset schema
  413. agg_schema: Schema of aggregation results
  414. aggregators: List of aggregators used to generate the results
  415. Returns:
  416. Tuple of (schema_matching_stats, schema_changing_stats, column_names)
  417. """
  418. schema_matching = {}
  419. schema_changing = {}
  420. columns = set()
  421. # Build a lookup map from "stat_name(col_name)" to aggregator
  422. agg_lookup = {agg.name: agg for agg in aggregators}
  423. for key, value in agg_result.items():
  424. if "(" not in key or not key.endswith(")"):
  425. continue
  426. # Get aggregator and extract info
  427. agg = agg_lookup.get(key)
  428. if not agg:
  429. continue
  430. col_name = agg.get_target_column()
  431. if not col_name:
  432. # Skip aggregations without a target column (e.g., Count())
  433. continue
  434. # Format the aggregation results
  435. agg_type = agg_schema.field(key).type
  436. original_type = original_schema.field(col_name).type
  437. formatted_stats = _format_stats(agg, value, agg_type)
  438. for stat_name, (stat_value, stat_type) in formatted_stats.items():
  439. # Add formatted stats to appropriate dict based on schema matching
  440. stats_dict = (
  441. schema_matching if stat_type == original_type else schema_changing
  442. )
  443. stats_dict.setdefault(stat_name, {})[col_name] = (stat_value, stat_type)
  444. columns.add(col_name)
  445. return schema_matching, schema_changing, columns
  446. def _create_pyarrow_array(
  447. col_data: List, col_type: Optional[pa.DataType] = None, col_name: str = ""
  448. ) -> pa.Array:
  449. """Create a PyArrow array with fallback strategies.
  450. Uses convert_to_pyarrow_array from arrow_block.py for type inference and
  451. error handling when no specific type is provided.
  452. Args:
  453. col_data: List of column values
  454. col_type: Optional PyArrow type to use
  455. col_name: Column name for error messages (optional)
  456. Returns:
  457. PyArrow array
  458. """
  459. if col_type is not None:
  460. try:
  461. return pa.array(col_data, type=col_type)
  462. except (pa.ArrowTypeError, pa.ArrowInvalid):
  463. # Type mismatch - fall through to type inference
  464. pass
  465. # Use convert_to_pyarrow_array for type inference and error handling
  466. # This handles tensors, extension types, and fallback to ArrowPythonObjectArray
  467. return convert_to_pyarrow_array(col_data, col_name or "column")
  468. def _build_summary_table(
  469. stats_dict: Dict[str, Dict[str, tuple]],
  470. all_columns: set,
  471. original_schema: pa.Schema,
  472. preserve_types: bool,
  473. ) -> pa.Table:
  474. """Build a PyArrow table from parsed statistics.
  475. Args:
  476. stats_dict: Nested dict of {stat_name: {col_name: (value, type)}}
  477. all_columns: Set of all column names across both tables
  478. original_schema: Original dataset schema
  479. preserve_types: If True, use original schema types for columns
  480. Returns:
  481. PyArrow table with statistics
  482. """
  483. if not stats_dict:
  484. return pa.table({})
  485. stat_names = sorted(stats_dict.keys())
  486. table_data = {DatasetSummary.STATISTIC_COLUMN: stat_names}
  487. for col_name in sorted(all_columns):
  488. # Collect values and infer type
  489. col_data = []
  490. first_type = None
  491. for stat_name in stat_names:
  492. if col_name in stats_dict[stat_name]:
  493. value, agg_type = stats_dict[stat_name][col_name]
  494. col_data.append(value)
  495. if first_type is None:
  496. first_type = agg_type
  497. else:
  498. col_data.append(None)
  499. # Determine column type: prefer original schema, then first aggregation type, then infer
  500. if preserve_types and col_name in original_schema.names:
  501. col_type = original_schema.field(col_name).type
  502. else:
  503. col_type = first_type
  504. table_data[col_name] = _create_pyarrow_array(col_data, col_type, col_name)
  505. return pa.table(table_data)