import hashlib
from collections import deque
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union

import ray
from ray.air.util.data_batch_conversion import BatchFormat
from ray.data.aggregate import AggregateFnV2
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.data.dataset import Dataset


@DeveloperAPI
def simple_split_tokenizer(value: str) -> List[str]:
    """Tokenize a string using a split on spaces."""
    return value.split(" ")


@DeveloperAPI
def simple_hash(value: object, num_features: int) -> int:
    """Deterministically hash a value into the integer range [0, num_features)."""
    # SHA-256 of the value's string form, so the result is stable across processes.
    encoded_value = str(value).encode()
    hashed_value = hashlib.sha256(encoded_value)
    hashed_value_int = int(hashed_value.hexdigest(), 16)
    return hashed_value_int % num_features


class BaseStatSpec:
    """Encapsulates a statistical computation with optional post-processing."""

    def __init__(
        self,
        *,
        stat_fn: Union[AggregateFnV2, Callable],
        post_process_fn: Callable = lambda x: x,
    ):
        self.stat_fn = stat_fn
        self.post_process_fn = post_process_fn


class AggregateStatSpec(BaseStatSpec):
    """Represents an AggregateFnV2 spec for a single column."""

    def __init__(
        self,
        *,
        aggregator_fn: Union[AggregateFnV2, Callable[[str], AggregateFnV2]],
        post_process_fn: Callable = lambda x: x,
        column: Optional[str] = None,
        batch_format: Optional[BatchFormat] = None,
    ):
        super().__init__(
            stat_fn=aggregator_fn,
            post_process_fn=post_process_fn,
        )
        self.column = column
        self.batch_format = batch_format


class CallableStatSpec(BaseStatSpec):
    """Represents a user-defined stat function that operates outside Dataset.aggregate."""

    def __init__(
        self,
        *,
        stat_fn: Callable,
        stat_key_fn: Optional[Callable[[str], str]],
        post_key_fn: Optional[Callable[[str], str]],
        post_process_fn: Callable = lambda x: x,
        columns: List[str],
    ):
        super().__init__(
            stat_fn=stat_fn,
            post_process_fn=post_process_fn,
        )
        self.columns = columns
        self.stat_key_fn = stat_key_fn
        self.post_key_fn = post_key_fn


class StatComputationPlan:
    """
    Encapsulates a set of aggregators (AggregateFnV2) and legacy stat functions
    to compute statistics over a Ray dataset.

    Supports two types of aggregations:
    1. AggregateFnV2-based aggregators, which are batch-executed using `Dataset.aggregate(...)`.
    2. Callable-based stat functions, executed sequentially (legacy use case).
    """

    def __init__(self):
        self._aggregators: Deque[BaseStatSpec] = deque()

    def reset(self):
        self._aggregators.clear()

    def add_aggregator(
        self,
        *,
        aggregator_fn: Callable[[str], AggregateFnV2],
        post_process_fn: Callable = lambda x: x,
        columns: List[str],
        batch_format: Optional[BatchFormat] = None,
    ) -> None:
        """
        Registers an AggregateFnV2 factory for one or more columns.

        Args:
            aggregator_fn: A callable (typically a lambda or class) that accepts a
                column name and returns an instance of AggregateFnV2. The aggregator
                should set its name using the alias_name parameter to control the
                output key.
            post_process_fn: Function to post-process the aggregated result.
            columns: List of column names to aggregate.
            batch_format: The batch format for aggregation results. If ARROW, results
                are kept in Arrow format for post_process_fn. Otherwise,
                results are converted to Python/pandas format.
        """
        for column in columns:
            # Instantiate one aggregator per column and register it.
            agg_instance = aggregator_fn(column)
            self._aggregators.append(
                AggregateStatSpec(
                    aggregator_fn=agg_instance,
                    post_process_fn=post_process_fn,
                    column=column,
                    batch_format=batch_format,
                )
            )

    def add_callable_stat(
        self,
        *,
        stat_fn: Callable[[Callable[[str], str]], Dict[str, Any]],
        stat_key_fn: Callable[[str], str],
        post_key_fn: Optional[Callable[[str], str]] = None,
        post_process_fn: Callable = lambda x: x,
        columns: List[str],
    ) -> None:
        """
        Registers a custom stat function to be run sequentially.

        This supports legacy use cases where arbitrary callables are needed
        and cannot be run via Dataset.aggregate().

        Args:
            stat_fn: A callable that compute() invokes as stat_fn(stat_key_fn). It
                must return a mapping that contains stat_key_fn(col) for every
                column in columns.
            stat_key_fn: A callable that takes a column name and returns the key for the stat.
            post_key_fn: Optional; a callable that takes a column name and returns the
                key under which the post-processed stat is stored. If not provided,
                stat_key_fn is used.
            post_process_fn: Function to post-process the result.
            columns: List of column names to compute the stat for.
        """
        self._aggregators.append(
            CallableStatSpec(
                stat_fn=stat_fn,
                post_process_fn=post_process_fn,
                columns=columns,
                stat_key_fn=stat_key_fn,
                post_key_fn=post_key_fn or stat_key_fn,
            )
        )

    def compute(self, dataset: "Dataset") -> Dict[str, Any]:
        """
        Executes all registered aggregators and stat functions.

        AggregateFnV2-based aggregators are batched and executed via Dataset.aggregate().
        Callable-based stat functions are run sequentially.

        Args:
            dataset: The Ray Dataset to compute statistics on.

        Returns:
            A dictionary of computed statistics.
        """
        stats = {}

        # Run batched aggregators (AggregateFnV2)
        aggregators = self._get_aggregate_fn_list()
        if aggregators:
            agg_ds = dataset.groupby(None).aggregate(*aggregators)
            arrow_refs = agg_ds.to_arrow_refs()
            if not arrow_refs:
                raise ValueError("Aggregation returned no results")
            arrow_table = ray.get(arrow_refs[0])

            for spec in self._get_aggregate_specs():
                stat_key = spec.stat_fn.name
                # Aggregation returns single row - extract the scalar value
                # ChunkedArray[0] handles multi-chunk arrays automatically
                agg_result = arrow_table.column(stat_key)[0]

                # Convert to appropriate format based on batch_format
                if spec.batch_format == BatchFormat.ARROW:
                    # Pass Arrow scalar (e.g., ListScalar) for Arrow-optimized post-processing
                    stats[stat_key] = spec.post_process_fn(agg_result)
                else:
                    # Convert to Python for pandas-style post-processing
                    stats[stat_key] = spec.post_process_fn(agg_result.as_py())

        # Run sequential stat functions
        for spec in self._get_custom_stat_fn_specs():
            result = spec.stat_fn(spec.stat_key_fn)
            for col in spec.columns:
                stat_key = spec.stat_key_fn(col)
                post_key = spec.post_key_fn(col)
                stats[post_key] = spec.post_process_fn(result[stat_key])

        return stats

    def _get_aggregate_fn_list(self) -> List[AggregateFnV2]:
        return [
            spec.stat_fn
            for spec in self._aggregators
            if isinstance(spec, AggregateStatSpec)
        ]

    def _get_aggregate_specs(self) -> List[AggregateStatSpec]:
        return [
            spec for spec in self._aggregators if isinstance(spec, AggregateStatSpec)
        ]

    def _get_custom_stat_fn_specs(self) -> List[CallableStatSpec]:
        return [
            spec for spec in self._aggregators if isinstance(spec, CallableStatSpec)
        ]

    def has_custom_stat_fn(self) -> bool:
        return len(self._get_custom_stat_fn_specs()) > 0

    def __iter__(self):
        """
        Iterates over the registered AggregateStatSpec instances.

        CallableStatSpec entries are not included.
        """
        return iter(self._get_aggregate_specs())


def make_post_processor(base_fn: Callable, callbacks: List[Callable]):
    """
    Wraps a base post-processing function with a sequence of callback functions.
    Useful when multiple post-processing steps need to be applied in order.
    """

    def wrapper(result):
        # Apply the base function first, then each callback in list order.
        processed = base_fn(result)
        for cb in callbacks:
            processed = cb(processed)
        return processed

    return wrapper
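
A small usage sketch for `make_post_processor`, showing the order in which the steps run. The helper is restated so the snippet runs on its own; `drop_none` and the sample values are hypothetical and only serve the illustration.

```python
from typing import Any, Callable, List


def make_post_processor(base_fn: Callable, callbacks: List[Callable]) -> Callable:
    """Same logic as the helper above: run base_fn, then each callback in order."""

    def wrapper(result: Any) -> Any:
        processed = base_fn(result)
        for cb in callbacks:
            processed = cb(processed)
        return processed

    return wrapper


def drop_none(values: List[Any]) -> List[Any]:
    """Hypothetical base step: discard missing values."""
    return [v for v in values if v is not None]


def index_values(values: List[Any]) -> dict:
    """Hypothetical final step: map each value to its position."""
    return {v: i for i, v in enumerate(values)}


# drop_none runs first, then sorted, then index_values.
post_process = make_post_processor(drop_none, [sorted, index_values])
result = post_process(["b", None, "a"])
print(result)
```

Because the base function runs before the callbacks, `None` is removed before `sorted` is called; reversing that order would raise a `TypeError` on the mixed list.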
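
The contract between a callable stat function and the sequential loop in `compute()` can be sketched without Ray. This is an illustration under assumptions, not the library's code path: `ROWS`, `max_stat_fn`, and `run_callable_stat` are hypothetical names, and a plain list of dicts stands in for a Ray Dataset.

```python
from typing import Any, Callable, Dict, List

# Hypothetical in-memory rows standing in for a Ray Dataset.
ROWS = [{"age": 20, "height": 150}, {"age": 40, "height": 170}]
COLUMNS = ["age", "height"]


def max_stat_fn(key_gen: Callable[[str], str]) -> Dict[str, Any]:
    """Hypothetical stat_fn: receives the key function, returns one entry per column."""
    return {key_gen(col): max(row[col] for row in ROWS) for col in COLUMNS}


def run_callable_stat(
    stat_fn: Callable[[Callable[[str], str]], Dict[str, Any]],
    stat_key_fn: Callable[[str], str],
    post_key_fn: Callable[[str], str],
    post_process_fn: Callable[[Any], Any],
    columns: List[str],
) -> Dict[str, Any]:
    """Mirrors the sequential loop in compute() for a single callable spec."""
    stats: Dict[str, Any] = {}
    result = stat_fn(stat_key_fn)  # stat_fn is called with the key function
    for col in columns:
        # Look up by stat key, store under post key after post-processing.
        stats[post_key_fn(col)] = post_process_fn(result[stat_key_fn(col)])
    return stats


stats = run_callable_stat(
    stat_fn=max_stat_fn,
    stat_key_fn=lambda col: f"max({col})",
    post_key_fn=lambda col: f"max_float({col})",
    post_process_fn=float,
    columns=COLUMNS,
)
print(stats)
```

The point of the sketch is the key handling: the stat function must produce keys generated by `stat_key_fn`, while the final dictionary is keyed by `post_key_fn`.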