utils.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import hashlib
  2. from collections import deque
  3. from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union
  4. import ray
  5. from ray.air.util.data_batch_conversion import BatchFormat
  6. from ray.data.aggregate import AggregateFnV2
  7. from ray.util.annotations import DeveloperAPI
  8. if TYPE_CHECKING:
  9. from ray.data.dataset import Dataset
  10. @DeveloperAPI
  11. def simple_split_tokenizer(value: str) -> List[str]:
  12. """Tokenize a string using a split on spaces."""
  13. return value.split(" ")
  14. @DeveloperAPI
  15. def simple_hash(value: object, num_features: int) -> int:
  16. """Deterministically hash a value into the integer space."""
  17. encoded_value = str(value).encode()
  18. hashed_value = hashlib.sha256(encoded_value)
  19. hashed_value_int = int(hashed_value.hexdigest(), 16)
  20. return hashed_value_int % num_features
  21. class BaseStatSpec:
  22. """Encapsulates a statistical computation with optional post-processing."""
  23. def __init__(
  24. self,
  25. *,
  26. stat_fn: Union[AggregateFnV2, Callable],
  27. post_process_fn: Callable = lambda x: x,
  28. ):
  29. self.stat_fn = stat_fn
  30. self.post_process_fn = post_process_fn
  31. class AggregateStatSpec(BaseStatSpec):
  32. """Represents an AggregateFnV2 spec for a single column."""
  33. def __init__(
  34. self,
  35. *,
  36. aggregator_fn: Union[AggregateFnV2, Callable[[str], AggregateFnV2]],
  37. post_process_fn: Callable = lambda x: x,
  38. column: Optional[str] = None,
  39. batch_format: Optional[BatchFormat] = None,
  40. ):
  41. super().__init__(
  42. stat_fn=aggregator_fn,
  43. post_process_fn=post_process_fn,
  44. )
  45. self.column = column
  46. self.batch_format = batch_format
  47. class CallableStatSpec(BaseStatSpec):
  48. """Represents a user-defined stat function that operates outside Dataset.aggregate."""
  49. def __init__(
  50. self,
  51. *,
  52. stat_fn: Callable,
  53. stat_key_fn: Optional[Callable[[str], str]],
  54. post_key_fn: Optional[Callable[[str], str]],
  55. post_process_fn: Callable = lambda x: x,
  56. columns: List[str],
  57. ):
  58. super().__init__(
  59. stat_fn=stat_fn,
  60. post_process_fn=post_process_fn,
  61. )
  62. self.columns = columns
  63. self.stat_key_fn = stat_key_fn
  64. self.post_key_fn = post_key_fn
  65. class StatComputationPlan:
  66. """
  67. Encapsulates a set of aggregators (AggregateFnV2) and legacy stat functions
  68. to compute statistics over a Ray dataset.
  69. Supports two types of aggregations:
  70. 1. AggregateFnV2-based aggregators, which are batch-executed using `Dataset.aggregate(...)`.
  71. 2. Callable-based stat functions, executed sequentially (legacy use case).
  72. """
  73. def __init__(self):
  74. self._aggregators: Deque[BaseStatSpec] = deque()
  75. def reset(self):
  76. self._aggregators.clear()
  77. def add_aggregator(
  78. self,
  79. *,
  80. aggregator_fn: Callable[[str], AggregateFnV2],
  81. post_process_fn: Callable = lambda x: x,
  82. columns: List[str],
  83. batch_format: Optional[BatchFormat] = None,
  84. ) -> None:
  85. """
  86. Registers an AggregateFnV2 factory for one or more columns.
  87. Args:
  88. aggregator_fn: A callable (typically a lambda or class) that accepts a column name and returns an instance of AggregateFnV2.
  89. The aggregator should set its name using alias_name parameter to control the output key.
  90. post_process_fn: Function to post-process the aggregated result.
  91. columns: List of column names to aggregate.
  92. batch_format: The batch format for aggregation results. If ARROW, results
  93. are kept in Arrow format for post_process_fn. Otherwise,
  94. results are converted to Python/pandas format.
  95. """
  96. for column in columns:
  97. agg_instance = aggregator_fn(column)
  98. self._aggregators.append(
  99. AggregateStatSpec(
  100. aggregator_fn=agg_instance,
  101. post_process_fn=post_process_fn,
  102. column=column,
  103. batch_format=batch_format,
  104. )
  105. )
  106. def add_callable_stat(
  107. self,
  108. *,
  109. stat_fn: Callable[[], Any],
  110. stat_key_fn: Callable[[str], str],
  111. post_key_fn: Optional[Callable[[str], str]] = None,
  112. post_process_fn: Callable = lambda x: x,
  113. columns: List[str],
  114. ) -> None:
  115. """
  116. Registers a custom stat function to be run sequentially.
  117. This supports legacy use cases where arbitrary callables are needed
  118. and cannot be run via Dataset.aggregate().
  119. Args:
  120. stat_fn: A zero-argument callable that returns the stat.
  121. stat_key_fn: A callable that takes a column name and returns the key for the stat.
  122. post_key_fn: Optional; a callable to post-process the key. If not provided, stat_key_fn is used.
  123. post_process_fn: Function to post-process the result.
  124. columns: List of column names to compute the stat for.
  125. """
  126. self._aggregators.append(
  127. CallableStatSpec(
  128. stat_fn=stat_fn,
  129. post_process_fn=post_process_fn,
  130. columns=columns,
  131. stat_key_fn=stat_key_fn,
  132. post_key_fn=post_key_fn or stat_key_fn,
  133. )
  134. )
  135. def compute(self, dataset: "Dataset") -> Dict[str, Any]:
  136. """
  137. Executes all registered aggregators and stat functions.
  138. AggregateFnV2-based aggregators are batched and executed via Dataset.aggregate().
  139. Callable-based stat functions are run sequentially.
  140. Args:
  141. dataset: The Ray Dataset to compute statistics on.
  142. Returns:
  143. A dictionary of computed statistics.
  144. """
  145. stats = {}
  146. # Run batched aggregators (AggregateFnV2)
  147. aggregators = self._get_aggregate_fn_list()
  148. if aggregators:
  149. agg_ds = dataset.groupby(None).aggregate(*aggregators)
  150. arrow_refs = agg_ds.to_arrow_refs()
  151. if not arrow_refs:
  152. raise ValueError("Aggregation returned no results")
  153. arrow_table = ray.get(arrow_refs[0])
  154. for spec in self._get_aggregate_specs():
  155. stat_key = spec.stat_fn.name
  156. # Aggregation returns single row - extract the scalar value
  157. # ChunkedArray[0] handles multi-chunk arrays automatically
  158. agg_result = arrow_table.column(stat_key)[0]
  159. # Convert to appropriate format based on batch_format
  160. if spec.batch_format == BatchFormat.ARROW:
  161. # Pass Arrow scalar (e.g., ListScalar) for Arrow-optimized post-processing
  162. stats[stat_key] = spec.post_process_fn(agg_result)
  163. else:
  164. # Convert to Python for pandas-style post-processing
  165. stats[stat_key] = spec.post_process_fn(agg_result.as_py())
  166. # Run sequential stat functions
  167. for spec in self._get_custom_stat_fn_specs():
  168. result = spec.stat_fn(spec.stat_key_fn)
  169. for col in spec.columns:
  170. stat_key = spec.stat_key_fn(col)
  171. post_key = spec.post_key_fn(col)
  172. stats[post_key] = spec.post_process_fn(result[stat_key])
  173. return stats
  174. def _get_aggregate_fn_list(self) -> List[AggregateFnV2]:
  175. return [
  176. spec.stat_fn
  177. for spec in self._aggregators
  178. if isinstance(spec, AggregateStatSpec)
  179. ]
  180. def _get_aggregate_specs(self) -> List[AggregateStatSpec]:
  181. return [
  182. spec for spec in self._aggregators if isinstance(spec, AggregateStatSpec)
  183. ]
  184. def _get_custom_stat_fn_specs(self) -> List[CallableStatSpec]:
  185. return [
  186. spec for spec in self._aggregators if isinstance(spec, CallableStatSpec)
  187. ]
  188. def has_custom_stat_fn(self):
  189. return len(self._get_custom_stat_fn_specs()) > 0
  190. def __iter__(self):
  191. """
  192. Iterates over all AggregatorSpecs.
  193. """
  194. return iter(self._get_aggregate_specs())
  195. def make_post_processor(base_fn, callbacks: List[Callable]):
  196. """
  197. Wraps a base post-processing function with a sequence of callback functions.
  198. Useful when multiple post-processing steps need to be applied in order.
  199. """
  200. def wrapper(result):
  201. processed = base_fn(result)
  202. for cb in callbacks:
  203. processed = cb(processed)
  204. return processed
  205. return wrapper