aggregate.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7
  1. import abc
  2. import enum
  3. import math
  4. import pickle
  5. import re
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Callable,
  10. Collection,
  11. Dict,
  12. Generic,
  13. List,
  14. Optional,
  15. Protocol,
  16. Set,
  17. TypeVar,
  18. Union,
  19. )
  20. import numpy as np
  21. import pyarrow.compute as pc
  22. from ray.data._internal.util import is_null
  23. from ray.data.block import (
  24. Block,
  25. BlockAccessor,
  26. BlockColumn,
  27. BlockColumnAccessor,
  28. KeyType,
  29. )
  30. from ray.util.annotations import Deprecated, PublicAPI
  31. if TYPE_CHECKING:
  32. from ray.data.dataset import Schema
  33. class _SupportsRichComparison(Protocol):
  34. def __lt__(self, other: Any) -> bool:
  35. ...
  36. def __le__(self, other: Any) -> bool:
  37. ...
  38. def __gt__(self, other: Any) -> bool:
  39. ...
  40. def __ge__(self, other: Any) -> bool:
  41. ...
  42. AccumulatorType = TypeVar("AccumulatorType")
  43. SupportsRichComparisonType = TypeVar(
  44. "SupportsRichComparisonType", bound=_SupportsRichComparison
  45. )
  46. AggOutputType = TypeVar("AggOutputType")
  47. _AGGREGATION_NAME_PATTERN = re.compile(r"^([^(]+)(?:\(.*\))?$")
  48. @Deprecated(message="AggregateFn is deprecated, please use AggregateFnV2")
  49. @PublicAPI
  50. class AggregateFn:
  51. """NOTE: THIS IS DEPRECATED, PLEASE USE :class:`AggregateFnV2` INSTEAD
  52. Defines how to perform a custom aggregation in Ray Data.
  53. `AggregateFn` instances are passed to a Dataset's ``.aggregate(...)`` method to
  54. specify the steps required to transform and combine rows sharing the same key.
  55. This enables implementing custom aggregators beyond the standard
  56. built-in options like Sum, Min, Max, Mean, etc.
  57. Args:
  58. init: Function that creates an initial aggregator for each group. Receives a key
  59. (the group key) and returns the initial accumulator state (commonly 0,
  60. an empty list, or an empty dictionary).
  61. merge: Function that merges two accumulators generated by different workers
  62. into one accumulator.
  63. name: An optional display name for the aggregator. Useful for debugging.
  64. accumulate_row: Function that processes an individual row. It receives the current
  65. accumulator and a row, then returns an updated accumulator. Cannot be
  66. used if `accumulate_block` is provided.
  67. accumulate_block: Function that processes an entire block of rows at once. It receives the
  68. current accumulator and a block of rows, then returns an updated accumulator.
  69. This allows for vectorized operations. Cannot be used if `accumulate_row`
  70. is provided.
  71. finalize: Function that finishes the aggregation by transforming the final
  72. accumulator state into the desired output. For example, if your
  73. accumulator is a list of items, you may want to compute a statistic
  74. from the list. If not provided, the final accumulator state is returned
  75. as-is.
  76. Example:
  77. .. testcode::
  78. import ray
  79. from ray.data.aggregate import AggregateFn
  80. # A simple aggregator that counts how many rows there are per group
  81. count_agg = AggregateFn(
  82. init=lambda k: 0,
  83. accumulate_row=lambda counter, row: counter + 1,
  84. merge=lambda c1, c2: c1 + c2,
  85. name="custom_count"
  86. )
  87. ds = ray.data.from_items([{"group": "A"}, {"group": "B"}, {"group": "A"}])
  88. result = ds.groupby("group").aggregate(count_agg).take_all()
  89. # result: [{'group': 'A', 'custom_count': 2}, {'group': 'B', 'custom_count': 1}]
  90. """
  91. def __init__(
  92. self,
  93. init: Callable[[KeyType], AccumulatorType],
  94. merge: Callable[[AccumulatorType, AccumulatorType], AccumulatorType],
  95. name: str,
  96. accumulate_row: Callable[
  97. [AccumulatorType, Dict[str, Any]], AccumulatorType
  98. ] = None,
  99. accumulate_block: Callable[[AccumulatorType, Block], AccumulatorType] = None,
  100. finalize: Optional[Callable[[AccumulatorType], AggOutputType]] = None,
  101. ):
  102. if (accumulate_row is None and accumulate_block is None) or (
  103. accumulate_row is not None and accumulate_block is not None
  104. ):
  105. raise ValueError(
  106. "Exactly one of accumulate_row or accumulate_block must be provided."
  107. )
  108. if accumulate_block is None:
  109. def accumulate_block(a: AccumulatorType, block: Block) -> AccumulatorType:
  110. block_acc = BlockAccessor.for_block(block)
  111. for r in block_acc.iter_rows(public_row_format=False):
  112. a = accumulate_row(a, r)
  113. return a
  114. if not isinstance(name, str):
  115. raise TypeError("`name` must be provided.")
  116. if finalize is None:
  117. finalize = lambda a: a # noqa: E731
  118. self.name = name
  119. self.init = init
  120. self.merge = merge
  121. self.accumulate_block = accumulate_block
  122. self.finalize = finalize
  123. def _validate(self, schema: Optional["Schema"]) -> None:
  124. """Raise an error if this cannot be applied to the given schema."""
  125. pass
  126. @PublicAPI(stability="alpha")
  127. class AggregateFnV2(AggregateFn, abc.ABC, Generic[AccumulatorType, AggOutputType]):
  128. """Provides an interface to implement efficient aggregations to be applied
  129. to the dataset.
  130. `AggregateFnV2` instances are passed to a Dataset's ``.aggregate(...)`` method to
  131. perform distributed aggregations. To create a custom aggregation, you should subclass
  132. `AggregateFnV2` and implement the `aggregate_block` and `combine` methods.
  133. The `finalize` method can also be overridden if the final accumulated state
  134. needs further transformation.
  135. Aggregation follows these steps:
  136. 1. **Initialization**: For each group (if grouping) or for the entire dataset,
  137. an initial accumulator is created using `zero_factory`.
  138. 2. **Block Aggregation**: The `aggregate_block` method is applied to
  139. each block independently, producing a partial aggregation result for that block.
  140. 3. **Combination**: The `combine` method is used to merge these partial
  141. results (or an existing accumulated result with a new partial result)
  142. into a single, combined accumulator.
  143. 4. **Finalization**: Optionally, the `finalize` method transforms the
  144. final combined accumulator into the desired output format.
  145. Generic Type Parameters:
  146. This class is parameterized by two type variables:
  147. - ``AccumulatorType``: The type of the intermediate state (accumulator) used
  148. during aggregation. This is what `aggregate_block` returns, what `combine`
  149. takes as inputs and returns, and what `finalize` receives. For simple
  150. aggregations like `Sum`, this might just be a numeric type. For more complex
  151. aggregations like `Mean`, this could be a composite type like
  152. ``List[Union[int, float]]`` representing ``[sum, count]``.
  153. - ``AggOutputType``: The type of the final result after `finalize` is called.
  154. This is what gets written to the output dataset. For `Sum`, this is the
  155. same as the accumulator type (a number). For `Mean`, the accumulator is
  156. ``[sum, count]`` but the output is a single ``float`` (the computed mean).
  157. Examples of type parameterization in built-in aggregations::
  158. Count(AggregateFnV2[int, int]) # accumulator: int, output: int
  159. Sum(AggregateFnV2[Union[int, float], ...]) # accumulator: number, output: number
  160. Mean(AggregateFnV2[List[...], float]) # accumulator: [sum, count], output: float
  161. Std(AggregateFnV2[List[...], float]) # accumulator: [M2, mean, count], output: float
  162. Args:
  163. name: The name of the aggregation. This will be used as the column name
  164. in the output, e.g., "sum(my_col)".
  165. zero_factory: A callable that returns the initial "zero" value for the
  166. accumulator. For example, for a sum, this would be `lambda: 0`; for
  167. finding a minimum, `lambda: float("inf")`, for finding a maximum,
  168. `lambda: float("-inf")`.
  169. on: The name of the column to perform the aggregation on. If `None`,
  170. the aggregation is performed over the entire row (e.g., for `Count()`).
  171. ignore_nulls: Whether to ignore null values during aggregation.
  172. If `True`, nulls are skipped.
  173. If `False`, the presence of a null value might result in a null output,
  174. depending on the aggregation logic.
  175. """
  176. def __init__(
  177. self,
  178. name: str,
  179. zero_factory: Callable[[], AccumulatorType],
  180. *,
  181. on: Optional[str],
  182. ignore_nulls: bool,
  183. ):
  184. if not name:
  185. raise ValueError(
  186. f"Non-empty string has to be provided as name (got {name})"
  187. )
  188. self._target_col_name = on
  189. self._ignore_nulls = ignore_nulls
  190. # Extract and store the agg name (e.g., "sum" from "sum(col)")
  191. # This avoids string parsing later
  192. match = _AGGREGATION_NAME_PATTERN.match(name)
  193. if match:
  194. self._agg_name = match.group(1)
  195. else:
  196. self._agg_name = name
  197. _safe_combine = _null_safe_combine(self.combine, ignore_nulls)
  198. _safe_aggregate = _null_safe_aggregate(self.aggregate_block, ignore_nulls)
  199. _safe_finalize = _null_safe_finalize(self.finalize)
  200. _safe_zero_factory = _null_safe_zero_factory(zero_factory, ignore_nulls)
  201. super().__init__(
  202. name=name,
  203. init=_safe_zero_factory,
  204. merge=_safe_combine,
  205. accumulate_block=lambda _, block: _safe_aggregate(block),
  206. finalize=_safe_finalize,
  207. )
  208. def get_target_column(self) -> Optional[str]:
  209. return self._target_col_name
  210. def get_agg_name(self) -> str:
  211. """Return the agg name (e.g., 'sum', 'mean', 'count').
  212. Returns the aggregation type extracted from the name during initialization.
  213. For example, returns 'sum' for an aggregator named 'sum(col)'.
  214. """
  215. return self._agg_name
  216. @abc.abstractmethod
  217. def combine(
  218. self, current_accumulator: AccumulatorType, new: AccumulatorType
  219. ) -> AccumulatorType:
  220. """Combines a new partial aggregation result with the current accumulator.
  221. This method defines how two intermediate aggregation states are merged.
  222. For example, if `aggregate_block` produces partial sums `s1` and `s2` from
  223. two different blocks, `combine(s1, s2)` should return `s1 + s2`.
  224. Args:
  225. current_accumulator: The current accumulated state (e.g., the result of
  226. previous `combine` calls or an initial value from `zero_factory`).
  227. new: A new partially aggregated value, typically the output of
  228. `aggregate_block` from a new block of data, or another accumulator
  229. from a parallel task.
  230. Returns:
  231. The updated accumulator after combining it with the new value.
  232. """
  233. ...
  234. @abc.abstractmethod
  235. def aggregate_block(self, block: Block) -> AccumulatorType:
  236. """Aggregates data within a single block.
  237. This method processes all rows in a given `Block` and returns a partial
  238. aggregation result for that block. For instance, if implementing a sum,
  239. this method would sum all relevant values within the block.
  240. Args:
  241. block: A `Block` of data to be aggregated.
  242. Returns:
  243. A partial aggregation result for the input block. The type of this
  244. result (`AggType`) should be consistent with the `current_accumulator`
  245. and `new` arguments of the `combine` method, and the `accumulator`
  246. argument of the `finalize` method.
  247. """
  248. ...
  249. def finalize(self, accumulator: AccumulatorType) -> Optional[AggOutputType]:
  250. """Transforms the final accumulated state into the desired output.
  251. This method is called once per group after all blocks have been processed
  252. and all partial results have been combined. It provides an opportunity
  253. to perform a final transformation on the accumulated data.
  254. For many aggregations (e.g., Sum, Count, Min, Max), the accumulated state
  255. is already the final result, so this method can simply return the
  256. accumulator as is (which is the default behavior).
  257. For other aggregations, like Mean, this method is crucial.
  258. A Mean aggregation might accumulate `[sum, count]`. The `finalize`
  259. method would then compute `sum / count` to get the final mean.
  260. Args:
  261. accumulator: The final accumulated state for a group, after all
  262. `aggregate_block` and `combine` operations.
  263. Returns:
  264. The final result of the aggregation for the group.
  265. """
  266. return accumulator
  267. def _validate(self, schema: Optional["Schema"]) -> None:
  268. if self._target_col_name:
  269. from ray.data._internal.planner.exchange.sort_task_spec import SortKey
  270. SortKey(self._target_col_name).validate_schema(schema)
  271. @PublicAPI
  272. class Count(AggregateFnV2[int, int]):
  273. """Defines count aggregation.
  274. Example:
  275. .. testcode::
  276. import ray
  277. from ray.data.aggregate import Count
  278. ds = ray.data.range(100)
  279. # Schema: {'id': int64}
  280. ds = ds.add_column("group_key", lambda x: x % 3)
  281. # Schema: {'id': int64, 'group_key': int64}
  282. # Counting all rows:
  283. result = ds.aggregate(Count())
  284. # result: {'count()': 100}
  285. # Counting all rows per group:
  286. result = ds.groupby("group_key").aggregate(Count(on="id")).take_all()
  287. # result: [{'group_key': 0, 'count(id)': 34},
  288. # {'group_key': 1, 'count(id)': 33},
  289. # {'group_key': 2, 'count(id)': 33}]
  290. Args:
  291. on: Optional name of the column to count values on. If None, counts rows.
  292. ignore_nulls: Whether to ignore null values when counting. Only applies if
  293. `on` is specified. Default is `False` which means `Count()` on a column
  294. will count nulls by default. To match pandas default behavior of not counting nulls,
  295. set `ignore_nulls=True`.
  296. alias_name: Optional name for the resulting column.
  297. """
  298. def __init__(
  299. self,
  300. on: Optional[str] = None,
  301. ignore_nulls: bool = False,
  302. alias_name: Optional[str] = None,
  303. ):
  304. super().__init__(
  305. alias_name if alias_name else f"count({on or ''})",
  306. on=on,
  307. ignore_nulls=ignore_nulls,
  308. zero_factory=lambda: 0,
  309. )
  310. def aggregate_block(self, block: Block) -> int:
  311. block_accessor = BlockAccessor.for_block(block)
  312. if self._target_col_name is None:
  313. # In case of global count, simply fetch number of rows
  314. return block_accessor.num_rows()
  315. return block_accessor.count(
  316. self._target_col_name, ignore_nulls=self._ignore_nulls
  317. )
  318. def combine(self, current_accumulator: int, new: int) -> int:
  319. return current_accumulator + new
  320. @PublicAPI
  321. class AsList(AggregateFnV2[List, List]):
  322. """Listing aggregation combining all values within the group into a single
  323. list element.
  324. Example:
  325. .. testcode::
  326. :skipif: True
  327. # Skip testing b/c this example require proper ordering of the output
  328. # to be robust and not flaky
  329. import ray
  330. from ray.data.aggregate import AsList
  331. ds = ray.data.range(10)
  332. # Schema: {'id': int64}
  333. ds = ds.add_column("group_key", lambda x: x % 3)
  334. # Schema: {'id': int64, 'group_key': int64}
  335. # Listing all elements per group:
  336. result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()
  337. # result: [{'group_key': 0, 'list(id)': [0, 3, 6, 9]},
  338. # {'group_key': 1, 'list(id)': [1, 4, 7]},
  339. # {'group_key': 2, 'list(id)': [2, 5, 8]}
  340. Args:
  341. on: The name of the column to collect values from. Must be provided.
  342. alias_name: Optional name for the resulting column.
  343. ignore_nulls: Whether to ignore null values when collecting. If `True`,
  344. nulls are skipped. If `False` (default), nulls are included in the list.
  345. """
  346. def __init__(
  347. self,
  348. on: str,
  349. alias_name: Optional[str] = None,
  350. ignore_nulls: bool = False,
  351. ):
  352. super().__init__(
  353. alias_name if alias_name else f"list({on or ''})",
  354. on=on,
  355. ignore_nulls=ignore_nulls,
  356. zero_factory=lambda: [],
  357. )
  358. def aggregate_block(self, block: Block) -> AccumulatorType:
  359. column_accessor = BlockColumnAccessor.for_column(
  360. block[self.get_target_column()]
  361. )
  362. if self._ignore_nulls:
  363. column_accessor = BlockColumnAccessor.for_column(column_accessor.dropna())
  364. return column_accessor.to_pylist()
  365. def combine(
  366. self, current_accumulator: AccumulatorType, new: AccumulatorType
  367. ) -> AccumulatorType:
  368. return current_accumulator + new
  369. @PublicAPI
  370. class Sum(AggregateFnV2[Union[int, float], Union[int, float]]):
  371. """Defines sum aggregation.
  372. Example:
  373. .. testcode::
  374. import ray
  375. from ray.data.aggregate import Sum
  376. ds = ray.data.range(100)
  377. # Schema: {'id': int64}
  378. ds = ds.add_column("group_key", lambda x: x % 3)
  379. # Schema: {'id': int64, 'group_key': int64}
  380. # Summing all rows per group:
  381. result = ds.aggregate(Sum(on="id"))
  382. # result: {'sum(id)': 4950}
  383. Args:
  384. on: The name of the numerical column to sum. Must be provided.
  385. ignore_nulls: Whether to ignore null values during summation. If `True` (default),
  386. nulls are skipped. If `False`, the sum will be null if any
  387. value in the group is null.
  388. alias_name: Optional name for the resulting column.
  389. """
  390. def __init__(
  391. self,
  392. on: Optional[str] = None,
  393. ignore_nulls: bool = True,
  394. alias_name: Optional[str] = None,
  395. ):
  396. super().__init__(
  397. alias_name if alias_name else f"sum({str(on)})",
  398. on=on,
  399. ignore_nulls=ignore_nulls,
  400. zero_factory=lambda: 0,
  401. )
  402. def aggregate_block(self, block: Block) -> Union[int, float]:
  403. return BlockAccessor.for_block(block).sum(
  404. self._target_col_name, self._ignore_nulls
  405. )
  406. def combine(
  407. self, current_accumulator: Union[int, float], new: Union[int, float]
  408. ) -> Union[int, float]:
  409. return current_accumulator + new
  410. @PublicAPI
  411. class Min(AggregateFnV2[SupportsRichComparisonType, SupportsRichComparisonType]):
  412. """Defines min aggregation.
  413. Example:
  414. .. testcode::
  415. import ray
  416. from ray.data.aggregate import Min
  417. ds = ray.data.range(100)
  418. # Schema: {'id': int64}
  419. ds = ds.add_column("group_key", lambda x: x % 3)
  420. # Schema: {'id': int64, 'group_key': int64}
  421. # Finding the minimum value per group:
  422. result = ds.groupby("group_key").aggregate(Min(on="id")).take_all()
  423. # result: [{'group_key': 0, 'min(id)': 0},
  424. # {'group_key': 1, 'min(id)': 1},
  425. # {'group_key': 2, 'min(id)': 2}]
  426. Args:
  427. on: The name of the column to find the minimum value from. Must be provided.
  428. ignore_nulls: Whether to ignore null values. If `True` (default), nulls are
  429. skipped. If `False`, the minimum will be null if any value in
  430. the group is null (for most data types, or follow type-specific
  431. comparison rules with nulls).
  432. alias_name: Optional name for the resulting column.
  433. zero_factory: A callable that returns the initial "zero" value for the
  434. accumulator. For example, for a float column, this would be
  435. `lambda: float("+inf")`. Default is `lambda: float("+inf")`.
  436. """
  437. def __init__(
  438. self,
  439. on: Optional[str] = None,
  440. ignore_nulls: bool = True,
  441. alias_name: Optional[str] = None,
  442. zero_factory: Callable[[], SupportsRichComparisonType] = lambda: float("+inf"),
  443. ):
  444. super().__init__(
  445. alias_name if alias_name else f"min({str(on)})",
  446. on=on,
  447. ignore_nulls=ignore_nulls,
  448. zero_factory=zero_factory,
  449. )
  450. def aggregate_block(self, block: Block) -> SupportsRichComparisonType:
  451. return BlockAccessor.for_block(block).min(
  452. self._target_col_name, self._ignore_nulls
  453. )
  454. def combine(
  455. self,
  456. current_accumulator: SupportsRichComparisonType,
  457. new: SupportsRichComparisonType,
  458. ) -> SupportsRichComparisonType:
  459. return min(current_accumulator, new)
  460. @PublicAPI
  461. class Max(AggregateFnV2[SupportsRichComparisonType, SupportsRichComparisonType]):
  462. """Defines max aggregation.
  463. Example:
  464. .. testcode::
  465. import ray
  466. from ray.data.aggregate import Max
  467. ds = ray.data.range(100)
  468. # Schema: {'id': int64}
  469. ds = ds.add_column("group_key", lambda x: x % 3)
  470. # Schema: {'id': int64, 'group_key': int64}
  471. # Finding the maximum value per group:
  472. result = ds.groupby("group_key").aggregate(Max(on="id")).take_all()
  473. # result: [{'group_key': 0, 'max(id)': ...},
  474. # {'group_key': 1, 'max(id)': ...},
  475. # {'group_key': 2, 'max(id)': ...}]
  476. Args:
  477. on: The name of the column to find the maximum value from. Must be provided.
  478. ignore_nulls: Whether to ignore null values. If `True` (default), nulls are
  479. skipped. If `False`, the maximum will be null if any value in
  480. the group is null (for most data types, or follow type-specific
  481. comparison rules with nulls).
  482. alias_name: Optional name for the resulting column.
  483. zero_factory: A callable that returns the initial "zero" value for the
  484. accumulator. For example, for a float column, this would be
  485. `lambda: float("-inf")`. Default is `lambda: float("-inf")`.
  486. """
  487. def __init__(
  488. self,
  489. on: Optional[str] = None,
  490. ignore_nulls: bool = True,
  491. alias_name: Optional[str] = None,
  492. zero_factory: Callable[[], SupportsRichComparisonType] = lambda: float("-inf"),
  493. ):
  494. super().__init__(
  495. alias_name if alias_name else f"max({str(on)})",
  496. on=on,
  497. ignore_nulls=ignore_nulls,
  498. zero_factory=zero_factory,
  499. )
  500. def aggregate_block(self, block: Block) -> SupportsRichComparisonType:
  501. return BlockAccessor.for_block(block).max(
  502. self._target_col_name, self._ignore_nulls
  503. )
  504. def combine(
  505. self,
  506. current_accumulator: SupportsRichComparisonType,
  507. new: SupportsRichComparisonType,
  508. ) -> SupportsRichComparisonType:
  509. return max(current_accumulator, new)
  510. @PublicAPI
  511. class Mean(AggregateFnV2[List[Union[int, float]], float]):
  512. """Defines mean (average) aggregation.
  513. Example:
  514. .. testcode::
  515. import ray
  516. from ray.data.aggregate import Mean
  517. ds = ray.data.range(100)
  518. # Schema: {'id': int64}
  519. ds = ds.add_column("group_key", lambda x: x % 3)
  520. # Schema: {'id': int64, 'group_key': int64}
  521. # Calculating the mean value per group:
  522. result = ds.groupby("group_key").aggregate(Mean(on="id")).take_all()
  523. # result: [{'group_key': 0, 'mean(id)': ...},
  524. # {'group_key': 1, 'mean(id)': ...},
  525. # {'group_key': 2, 'mean(id)': ...}]
  526. Args:
  527. on: The name of the numerical column to calculate the mean on. Must be provided.
  528. ignore_nulls: Whether to ignore null values. If `True` (default), nulls are
  529. skipped. If `False`, the mean will be null if any value in the
  530. group is null.
  531. alias_name: Optional name for the resulting column.
  532. """
  533. def __init__(
  534. self,
  535. on: Optional[str] = None,
  536. ignore_nulls: bool = True,
  537. alias_name: Optional[str] = None,
  538. ):
  539. super().__init__(
  540. alias_name if alias_name else f"mean({str(on)})",
  541. on=on,
  542. ignore_nulls=ignore_nulls,
  543. # The accumulator is: [current_sum, current_count].
  544. # NOTE: We copy the returned list `list([0,0])` as some internal mechanisms
  545. # might modify accumulators in-place.
  546. zero_factory=lambda: list([0, 0]), # noqa: C410
  547. )
  548. def aggregate_block(self, block: Block) -> Optional[List[Union[int, float]]]:
  549. block_acc = BlockAccessor.for_block(block)
  550. count = block_acc.count(self._target_col_name, self._ignore_nulls)
  551. if count == 0 or count is None:
  552. # Empty or all null.
  553. return None
  554. sum_ = block_acc.sum(self._target_col_name, self._ignore_nulls)
  555. if is_null(sum_):
  556. # In case of ignore_nulls=False and column containing 'null'
  557. # return as is (to prevent unnecessary type conversions, when, for ex,
  558. # using Pandas and returning None)
  559. return sum_
  560. return [sum_, count]
  561. def combine(
  562. self, current_accumulator: List[Union[int, float]], new: List[Union[int, float]]
  563. ) -> List[Union[int, float]]:
  564. return [current_accumulator[0] + new[0], current_accumulator[1] + new[1]]
  565. def finalize(self, accumulator: List[Union[int, float]]) -> Optional[float]:
  566. # The final accumulator for a group is [total_sum, total_count].
  567. if accumulator[1] == 0:
  568. # If total_count is 0 (e.g., group was empty or all nulls ignored),
  569. # the mean is undefined. Return NaN
  570. return np.nan
  571. return accumulator[0] / accumulator[1]
  572. @PublicAPI
  573. class Std(AggregateFnV2[List[Union[int, float]], float]):
  574. """Defines standard deviation aggregation.
  575. Uses Welford's online algorithm for numerical stability. This method computes
  576. the standard deviation in a single pass. Results may differ slightly from
  577. libraries like NumPy or Pandas that use a two-pass algorithm but are generally
  578. more accurate.
  579. See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
  580. Example:
  581. .. testcode::
  582. import ray
  583. from ray.data.aggregate import Std
  584. ds = ray.data.range(100)
  585. # Schema: {'id': int64}
  586. ds = ds.add_column("group_key", lambda x: x % 3)
  587. # Schema: {'id': int64, 'group_key': int64}
  588. # Calculating the standard deviation per group:
  589. result = ds.groupby("group_key").aggregate(Std(on="id")).take_all()
  590. # result: [{'group_key': 0, 'std(id)': ...},
  591. # {'group_key': 1, 'std(id)': ...},
  592. # {'group_key': 2, 'std(id)': ...}]
  593. Args:
  594. on: The name of the column to calculate standard deviation on.
  595. ddof: Delta Degrees of Freedom. The divisor used in calculations is `N - ddof`,
  596. where `N` is the number of elements. Default is 1.
  597. ignore_nulls: Whether to ignore null values. Default is True.
  598. alias_name: Optional name for the resulting column.
  599. """
  600. def __init__(
  601. self,
  602. on: Optional[str] = None,
  603. ddof: int = 1,
  604. ignore_nulls: bool = True,
  605. alias_name: Optional[str] = None,
  606. ):
  607. super().__init__(
  608. alias_name if alias_name else f"std({str(on)})",
  609. on=on,
  610. ignore_nulls=ignore_nulls,
  611. # Accumulator: [M2, mean, count]
  612. # M2: sum of squares of differences from the current mean
  613. # mean: current mean
  614. # count: current count of non-null elements
  615. # We need to copy the list as it might be modified in-place by some aggregations.
  616. zero_factory=lambda: list([0, 0, 0]), # noqa: C410
  617. )
  618. self._ddof = ddof
  619. def aggregate_block(self, block: Block) -> List[Union[int, float]]:
  620. block_acc = BlockAccessor.for_block(block)
  621. count = block_acc.count(self._target_col_name, ignore_nulls=self._ignore_nulls)
  622. if count == 0 or count is None:
  623. # Empty or all null.
  624. return None
  625. sum_ = block_acc.sum(self._target_col_name, self._ignore_nulls)
  626. if is_null(sum_):
  627. # If sum is null (e.g., ignore_nulls=False and a null was encountered),
  628. # return as is to prevent type conversions.
  629. return sum_
  630. mean = sum_ / count
  631. M2 = block_acc.sum_of_squared_diffs_from_mean(
  632. self._target_col_name, self._ignore_nulls, mean
  633. )
  634. return [M2, mean, count]
  635. def combine(
  636. self, current_accumulator: List[float], new: List[float]
  637. ) -> List[float]:
  638. # Merges two accumulators [M2, mean, count] using a parallel algorithm.
  639. # See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
  640. M2_a, mean_a, count_a = current_accumulator
  641. M2_b, mean_b, count_b = new
  642. delta = mean_b - mean_a
  643. count = count_a + count_b
  644. # NOTE: We use this mean calculation since it's more numerically
  645. # stable than mean_a + delta * count_b / count, which actually
  646. # deviates from Pandas in the ~15th decimal place and causes our
  647. # exact comparison tests to fail.
  648. mean = (mean_a * count_a + mean_b * count_b) / count
  649. # Update the sum of squared differences.
  650. M2 = M2_a + M2_b + (delta**2) * count_a * count_b / count
  651. return [M2, mean, count]
  652. def finalize(self, accumulator: List[float]) -> Optional[float]:
  653. # Compute the final standard deviation from the accumulated
  654. # sum of squared differences from current mean and the count.
  655. # Final accumulator: [M2, mean, count]
  656. M2, mean, count = accumulator
  657. # Denominator for variance calculation is count - ddof
  658. if count - self._ddof <= 0:
  659. # If count - ddof is not positive, variance/std is undefined (or zero).
  660. # Return NaN, consistent with pandas/numpy.
  661. return np.nan
  662. # Standard deviation is the square root of variance (M2 / (count - ddof))
  663. return math.sqrt(M2 / (count - self._ddof))
  664. @PublicAPI
  665. class AbsMax(AggregateFnV2[SupportsRichComparisonType, SupportsRichComparisonType]):
  666. """Defines absolute max aggregation.
  667. Example:
  668. .. testcode::
  669. import ray
  670. from ray.data.aggregate import AbsMax
  671. ds = ray.data.range(100)
  672. # Schema: {'id': int64}
  673. ds = ds.add_column("group_key", lambda x: x % 3)
  674. # Schema: {'id': int64, 'group_key': int64}
  675. # Calculating the absolute maximum value per group:
  676. result = ds.groupby("group_key").aggregate(AbsMax(on="id")).take_all()
  677. # result: [{'group_key': 0, 'abs_max(id)': ...},
  678. # {'group_key': 1, 'abs_max(id)': ...},
  679. # {'group_key': 2, 'abs_max(id)': ...}]
  680. Args:
  681. on: The name of the column to calculate absolute maximum on. Must be provided.
  682. ignore_nulls: Whether to ignore null values. Default is True.
  683. alias_name: Optional name for the resulting column.
  684. zero_factory: A callable that returns the initial "zero" value for the
  685. accumulator. For example, for a float column, this would be
  686. `lambda: 0`. Default is `lambda: 0`.
  687. """
  688. def __init__(
  689. self,
  690. on: Optional[str] = None,
  691. ignore_nulls: bool = True,
  692. alias_name: Optional[str] = None,
  693. zero_factory: Callable[[], SupportsRichComparisonType] = lambda: 0,
  694. ):
  695. if on is None or not isinstance(on, str):
  696. raise ValueError(f"Column to aggregate on has to be provided (got {on})")
  697. super().__init__(
  698. alias_name if alias_name else f"abs_max({str(on)})",
  699. on=on,
  700. ignore_nulls=ignore_nulls,
  701. zero_factory=zero_factory,
  702. )
  703. def aggregate_block(self, block: Block) -> Optional[SupportsRichComparisonType]:
  704. block_accessor = BlockAccessor.for_block(block)
  705. max_ = block_accessor.max(self._target_col_name, self._ignore_nulls)
  706. min_ = block_accessor.min(self._target_col_name, self._ignore_nulls)
  707. if is_null(max_) or is_null(min_):
  708. return None
  709. return max(abs(max_), abs(min_))
  710. def combine(
  711. self,
  712. current_accumulator: SupportsRichComparisonType,
  713. new: SupportsRichComparisonType,
  714. ) -> SupportsRichComparisonType:
  715. return max(current_accumulator, new)
  716. @PublicAPI
  717. class Quantile(AggregateFnV2[List[Any], List[Any]]):
  718. """Defines Quantile aggregation.
  719. Example:
  720. .. testcode::
  721. import ray
  722. from ray.data.aggregate import Quantile
  723. ds = ray.data.range(100)
  724. # Schema: {'id': int64}
  725. ds = ds.add_column("group_key", lambda x: x % 3)
  726. # Schema: {'id': int64, 'group_key': int64}
  727. # Calculating the 50th percentile (median) per group:
  728. result = ds.groupby("group_key").aggregate(Quantile(q=0.5, on="id")).take_all()
  729. # result: [{'group_key': 0, 'quantile(id)': ...},
  730. # {'group_key': 1, 'quantile(id)': ...},
  731. # {'group_key': 2, 'quantile(id)': ...}]
  732. Args:
  733. on: The name of the column to calculate the quantile on. Must be provided.
  734. q: The quantile to compute, which must be between 0 and 1 inclusive.
  735. For example, q=0.5 computes the median.
  736. ignore_nulls: Whether to ignore null values. Default is True.
  737. alias_name: Optional name for the resulting column.
  738. """
  739. def __init__(
  740. self,
  741. on: Optional[str] = None,
  742. q: float = 0.5,
  743. ignore_nulls: bool = True,
  744. alias_name: Optional[str] = None,
  745. ):
  746. self._q = q
  747. super().__init__(
  748. alias_name if alias_name else f"quantile({str(on)})",
  749. on=on,
  750. ignore_nulls=ignore_nulls,
  751. zero_factory=list,
  752. )
  753. def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
  754. if isinstance(current_accumulator, List) and isinstance(new, List):
  755. current_accumulator.extend(new)
  756. return current_accumulator
  757. if isinstance(current_accumulator, List) and (not isinstance(new, List)):
  758. if new is not None and new != "":
  759. current_accumulator.append(new)
  760. return current_accumulator
  761. if isinstance(new, List) and (not isinstance(current_accumulator, List)):
  762. if current_accumulator is not None and current_accumulator != "":
  763. new.append(current_accumulator)
  764. return new
  765. ls = []
  766. if current_accumulator is not None and current_accumulator != "":
  767. ls.append(current_accumulator)
  768. if new is not None and new != "":
  769. ls.append(new)
  770. return ls
  771. def aggregate_block(self, block: Block) -> List[Any]:
  772. block_acc = BlockAccessor.for_block(block)
  773. ls = []
  774. for row in block_acc.iter_rows(public_row_format=False):
  775. ls.append(row.get(self._target_col_name))
  776. return ls
  777. def finalize(self, accumulator: List[Any]) -> Optional[Any]:
  778. if self._ignore_nulls:
  779. accumulator = [v for v in accumulator if not is_null(v)]
  780. else:
  781. nulls = [v for v in accumulator if is_null(v)]
  782. if len(nulls) > 0:
  783. # If nulls are present and not ignored, the quantile is undefined.
  784. # Return the first null encountered to preserve column type.
  785. return nulls[0]
  786. if not accumulator:
  787. # If the list is empty (e.g., all values were null and ignored, or no values),
  788. # quantile is undefined.
  789. return None
  790. key = lambda x: x # noqa: E731
  791. input_values = sorted(accumulator)
  792. k = (len(input_values) - 1) * self._q
  793. f = math.floor(k)
  794. c = math.ceil(k)
  795. if f == c:
  796. return key(input_values[int(k)])
  797. # Interpolate between the elements at floor and ceil indices.
  798. d0 = key(input_values[int(f)]) * (c - k)
  799. d1 = key(input_values[int(c)]) * (k - f)
  800. return round(d0 + d1, 5)
  801. @PublicAPI
  802. class Unique(AggregateFnV2[Set[Any], List[Any]]):
  803. """Defines unique aggregation.
  804. Example:
  805. .. testcode::
  806. import ray
  807. from ray.data.aggregate import Unique
  808. ds = ray.data.range(100)
  809. ds = ds.add_column("group_key", lambda x: x % 3)
  810. # Calculating the unique values per group:
  811. result = ds.groupby("group_key").aggregate(Unique(on="id")).take_all()
  812. # result: [{'group_key': 0, 'unique(id)': ...},
  813. # {'group_key': 1, 'unique(id)': ...},
  814. # {'group_key': 2, 'unique(id)': ...}]
  815. Args:
  816. on: The name of the column from which to collect unique values.
  817. ignore_nulls: Whether to ignore null values when collecting unique items.
  818. Default is True (nulls are excluded).
  819. alias_name: Optional name for the resulting column.
  820. encode_lists: If `True`, encode list elements. If `False`, encode
  821. whole lists (i.e., the entire list is considered as a single object).
  822. `False` by default. Note that this is a top-level flatten (not a recursive
  823. flatten) operation.
  824. """
  825. class ListEncodingMode(str, enum.Enum):
  826. """Controls how to encode individual elements inside the list column:
  827. - NONE: no encoding applied, elements (lists) are stored as is and
  828. unique ones are returned.
  829. - FLATTEN: column of element lists is flattened into a single list.
  830. - HASH: each list element is hashed, a list of unique hashes is returned.
  831. """
  832. FLATTEN = "FLATTEN"
  833. HASH = "HASH"
  834. def __init__(
  835. self,
  836. on: Optional[str] = None,
  837. ignore_nulls: bool = False,
  838. alias_name: Optional[str] = None,
  839. encode_lists: Union[bool, ListEncodingMode, None] = None,
  840. ):
  841. super().__init__(
  842. alias_name if alias_name else f"unique({str(on)})",
  843. on=on,
  844. ignore_nulls=ignore_nulls,
  845. zero_factory=set,
  846. )
  847. if isinstance(encode_lists, Unique.ListEncodingMode):
  848. self._list_encoding_mode = encode_lists
  849. elif isinstance(encode_lists, bool) and encode_lists:
  850. self._list_encoding_mode = Unique.ListEncodingMode.FLATTEN
  851. else:
  852. self._list_encoding_mode = None
  853. def combine(self, current_accumulator: Set[Any], new: Set[Any]) -> Set[Any]:
  854. return self._to_set(current_accumulator) | self._to_set(new)
  855. def _compute_unique(self, block: Block) -> BlockColumn:
  856. column = block[self._target_col_name]
  857. column_accessor = BlockColumnAccessor.for_column(column)
  858. if (
  859. column_accessor.is_composed_of_lists()
  860. and self._list_encoding_mode is not None
  861. ):
  862. if self._list_encoding_mode == Unique.ListEncodingMode.FLATTEN:
  863. column_accessor = BlockColumnAccessor.for_column(
  864. column_accessor.flatten()
  865. )
  866. elif self._list_encoding_mode == Unique.ListEncodingMode.HASH:
  867. column_accessor = BlockColumnAccessor.for_column(column_accessor.hash())
  868. else:
  869. raise ValueError(
  870. f"list encoding mode not supported: {self._list_encoding_mode}"
  871. )
  872. if self._ignore_nulls:
  873. column_accessor = BlockColumnAccessor.for_column(column_accessor.dropna())
  874. return column_accessor.unique()
  875. def aggregate_block(self, block: Block) -> List[Any]:
  876. column = self._compute_unique(block)
  877. return BlockColumnAccessor.for_column(column).to_pylist()
  878. @staticmethod
  879. def _to_set(x):
  880. if isinstance(x, set):
  881. return Unique._normalize_nans(x)
  882. elif isinstance(x, list):
  883. if len(x) > 0 and isinstance(x[0], list):
  884. # necessary because pyarrow converts all tuples to
  885. # list internally.
  886. x = map(lambda v: None if v is None else tuple(v), x)
  887. return Unique._normalize_nans(x)
  888. else:
  889. return {x}
  890. @staticmethod
  891. def _normalize_nans(x: Collection) -> Set:
  892. # NOTE: Pandas when converting to Python objects instantiates
  893. # new `float('nan')` objects which are incomparable b/w each
  894. # other. Here we canonicalize any nan instances replacing them
  895. # w/ `np.nan`
  896. return {v if not (isinstance(v, float) and np.isnan(v)) else np.nan for v in x}
  897. @PublicAPI
  898. class CountDistinct(Unique):
  899. """Defines distinct count aggregation.
  900. This aggregation computes the count of distinct values in a column.
  901. It is similar to SQL's COUNT(DISTINCT column_name) operation.
  902. Example:
  903. .. testcode::
  904. import ray
  905. from ray.data.aggregate import CountDistinct
  906. # Create a dataset with repeated values
  907. ds = ray.data.from_items([
  908. {"category": "A"}, {"category": "B"}, {"category": "A"},
  909. {"category": "C"}, {"category": "A"}, {"category": "B"}
  910. ])
  911. # Count distinct categories
  912. result = ds.aggregate(CountDistinct(on="category"))
  913. # result: {'count_distinct(category)': 3}
  914. # Using with groupby
  915. ds = ray.data.from_items([
  916. {"group": "X", "category": "A"}, {"group": "X", "category": "B"},
  917. {"group": "Y", "category": "A"}, {"group": "Y", "category": "A"}
  918. ])
  919. result = ds.groupby("group").aggregate(CountDistinct(on="category")).take_all()
  920. # result: [{'group': 'X', 'count_distinct(category)': 2},
  921. # {'group': 'Y', 'count_distinct(category)': 1}]
  922. Args:
  923. on: The name of the column to count distinct values on.
  924. ignore_nulls: Whether to ignore null values when counting distinct items.
  925. Default is True (nulls are excluded from the count).
  926. alias_name: Optional name for the resulting column. If not provided,
  927. defaults to "count_distinct({on})".
  928. """
  929. def __init__(
  930. self,
  931. on: str,
  932. ignore_nulls: bool = True,
  933. alias_name: Optional[str] = None,
  934. ):
  935. super().__init__(
  936. on=on,
  937. ignore_nulls=ignore_nulls,
  938. alias_name=alias_name if alias_name else f"count_distinct({str(on)})",
  939. )
  940. def finalize(self, accumulator: Set[Any]) -> int:
  941. """Return the count of distinct values."""
  942. return len(accumulator)
  943. @PublicAPI
  944. class ValueCounter(AggregateFnV2):
  945. """Counts the number of times each value appears in a column.
  946. This aggregation computes value counts for a specified column, similar to pandas'
  947. `value_counts()` method. It returns a dictionary with two lists: "values" containing
  948. the unique values found in the column, and "counts" containing the corresponding
  949. count for each value.
  950. Example:
  951. .. testcode::
  952. import ray
  953. from ray.data.aggregate import ValueCounter
  954. # Create a dataset with repeated values
  955. ds = ray.data.from_items([
  956. {"category": "A"}, {"category": "B"}, {"category": "A"},
  957. {"category": "C"}, {"category": "A"}, {"category": "B"}
  958. ])
  959. # Count occurrences of each category
  960. result = ds.aggregate(ValueCounter(on="category"))
  961. # result: {'value_counter(category)': {'values': ['A', 'B', 'C'], 'counts': [3, 2, 1]}}
  962. # Using with groupby
  963. ds = ray.data.from_items([
  964. {"group": "X", "category": "A"}, {"group": "X", "category": "B"},
  965. {"group": "Y", "category": "A"}, {"group": "Y", "category": "A"}
  966. ])
  967. result = ds.groupby("group").aggregate(ValueCounter(on="category")).take_all()
  968. # result: [{'group': 'X', 'value_counter(category)': {'values': ['A', 'B'], 'counts': [1, 1]}},
  969. # {'group': 'Y', 'value_counter(category)': {'values': ['A'], 'counts': [2]}}]
  970. Args:
  971. on: The name of the column to count values in. Must be provided.
  972. alias_name: Optional name for the resulting column. If not provided,
  973. defaults to "value_counter({column_name})".
  974. """
  975. def __init__(
  976. self,
  977. on: str,
  978. alias_name: Optional[str] = None,
  979. ):
  980. super().__init__(
  981. alias_name if alias_name else f"value_counter({str(on)})",
  982. on=on,
  983. ignore_nulls=True,
  984. zero_factory=lambda: {"values": [], "counts": []},
  985. )
  986. def aggregate_block(self, block: Block) -> Dict[str, List]:
  987. col_accessor = BlockColumnAccessor.for_column(block[self._target_col_name])
  988. return col_accessor.value_counts()
  989. def combine(
  990. self,
  991. current_accumulator: Dict[str, List],
  992. new_accumulator: Dict[str, List],
  993. ) -> Dict[str, List]:
  994. values = current_accumulator["values"]
  995. counts = current_accumulator["counts"]
  996. # Build a value → index map once (avoid repeated lookups)
  997. value_to_index = {v: i for i, v in enumerate(values)}
  998. for v_new, c_new in zip(new_accumulator["values"], new_accumulator["counts"]):
  999. if v_new in value_to_index:
  1000. idx = value_to_index[v_new]
  1001. counts[idx] += c_new
  1002. else:
  1003. value_to_index[v_new] = len(values)
  1004. values.append(v_new)
  1005. counts.append(c_new)
  1006. return current_accumulator
  1007. def _null_safe_zero_factory(zero_factory, ignore_nulls: bool):
  1008. """NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
  1009. Null-safe zero factory is crucial for implementing proper aggregation
  1010. protocol (monoid) w/o the need for additional containers.
  1011. Main hurdle for implementing proper aggregation semantic is to be able to encode
  1012. semantic of an "empty accumulator" and be able to tell it from the case when
  1013. accumulator is actually holding null value:
  1014. - Empty container can be overridden with any value
  1015. - Container holding null can't be overridden if ignore_nulls=False
  1016. However, it's possible for us to exploit asymmetry in cases of ignore_nulls being
  1017. True or False:
  1018. - Case of ignore_nulls=False entails that if there's any "null" in the sequence,
  1019. aggregation is undefined and correspondingly expected to return null
  1020. - Case of ignore_nulls=True in turn, entails that if aggregation returns "null"
  1021. if and only if the sequence does NOT have any non-null value
  1022. Therefore, we apply this difference in semantic to zero-factory to make sure that
  1023. our aggregation protocol is adherent to that definition:
  1024. - If ignore_nulls=True, zero-factory returns null, therefore encoding empty
  1025. container
  1026. - If ignore_nulls=False, couldn't return null as aggregation will incorrectly
  1027. prioritize it, and instead it returns true zero value for the aggregation
  1028. (ie 0 for count/sum, -inf for max, etc).
  1029. """
  1030. if ignore_nulls:
  1031. def _safe_zero_factory(_):
  1032. return None
  1033. else:
  1034. def _safe_zero_factory(_):
  1035. return zero_factory()
  1036. return _safe_zero_factory
  1037. def _null_safe_aggregate(
  1038. aggregate: Callable[[Block], AccumulatorType],
  1039. ignore_nulls: bool,
  1040. ) -> Callable[[Block], Optional[AccumulatorType]]:
  1041. def _safe_aggregate(block: Block) -> Optional[AccumulatorType]:
  1042. result = aggregate(block)
  1043. # NOTE: If `ignore_nulls=True`, aggregation will only be returning
  1044. # null if the block does NOT contain any non-null elements
  1045. if is_null(result) and ignore_nulls:
  1046. return None
  1047. return result
  1048. return _safe_aggregate
  1049. def _null_safe_finalize(
  1050. finalize: Callable[[AccumulatorType], AccumulatorType],
  1051. ) -> Callable[[Optional[AccumulatorType]], AccumulatorType]:
  1052. def _safe_finalize(acc: Optional[AccumulatorType]) -> AccumulatorType:
  1053. # If accumulator container is not null, finalize.
  1054. # Otherwise, return as is.
  1055. return acc if is_null(acc) else finalize(acc)
  1056. return _safe_finalize
  1057. def _null_safe_combine(
  1058. combine: Callable[[AccumulatorType, AccumulatorType], AccumulatorType],
  1059. ignore_nulls: bool,
  1060. ) -> Callable[
  1061. [Optional[AccumulatorType], Optional[AccumulatorType]], Optional[AccumulatorType]
  1062. ]:
  1063. """Null-safe combination have to be an associative operation
  1064. with an identity element (zero) or in other words implement a monoid.
  1065. To achieve that in the presence of null values following semantic is
  1066. established:
  1067. - Case of ignore_nulls=True:
  1068. - If current accumulator is null (ie empty), return new accumulator
  1069. - If new accumulator is null (ie empty), return cur
  1070. - Otherwise combine (current and new)
  1071. - Case of ignore_nulls=False:
  1072. - If new accumulator is null (ie has null in the sequence, b/c we're
  1073. NOT ignoring nulls), return it
  1074. - If current accumulator is null (ie had null in the prior sequence,
  1075. b/c we're NOT ignoring nulls), return it
  1076. - Otherwise combine (current and new)
  1077. """
  1078. if ignore_nulls:
  1079. def _safe_combine(
  1080. cur: Optional[AccumulatorType], new: Optional[AccumulatorType]
  1081. ) -> Optional[AccumulatorType]:
  1082. if is_null(cur):
  1083. return new
  1084. elif is_null(new):
  1085. return cur
  1086. else:
  1087. return combine(cur, new)
  1088. else:
  1089. def _safe_combine(
  1090. cur: Optional[AccumulatorType], new: Optional[AccumulatorType]
  1091. ) -> Optional[AccumulatorType]:
  1092. if is_null(new):
  1093. return new
  1094. elif is_null(cur):
  1095. return cur
  1096. else:
  1097. return combine(cur, new)
  1098. return _safe_combine
  1099. @PublicAPI(stability="alpha")
  1100. class MissingValuePercentage(AggregateFnV2[List[int], float]):
  1101. """Calculates the percentage of null values in a column.
  1102. This aggregation computes the percentage of null (missing) values in a dataset column.
  1103. It treats both None values and NaN values as null. The result is a percentage value
  1104. between 0.0 and 100.0, where 0.0 means no missing values and 100.0 means all values
  1105. are missing.
  1106. Example:
  1107. .. testcode::
  1108. import ray
  1109. from ray.data.aggregate import MissingValuePercentage
  1110. # Create a dataset with some missing values
  1111. ds = ray.data.from_items([
  1112. {"value": 1}, {"value": None}, {"value": 3},
  1113. {"value": None}, {"value": 5}
  1114. ])
  1115. # Calculate missing value percentage
  1116. result = ds.aggregate(MissingValuePercentage(on="value"))
  1117. # result: 40.0 (2 out of 5 values are missing)
  1118. # Using with groupby
  1119. ds = ray.data.from_items([
  1120. {"group": "A", "value": 1}, {"group": "A", "value": None},
  1121. {"group": "B", "value": 3}, {"group": "B", "value": None}
  1122. ])
  1123. result = ds.groupby("group").aggregate(MissingValuePercentage(on="value")).take_all()
  1124. # result: [{'group': 'A', 'missing_pct(value)': 50.0},
  1125. # {'group': 'B', 'missing_pct(value)': 50.0}]
  1126. Args:
  1127. on: The name of the column to calculate missing value percentage on.
  1128. alias_name: Optional name for the resulting column. If not provided,
  1129. defaults to "missing_pct({column_name})".
  1130. """
  1131. def __init__(
  1132. self,
  1133. on: str,
  1134. alias_name: Optional[str] = None,
  1135. ):
  1136. # Initialize with a list accumulator [null_count, total_count]
  1137. super().__init__(
  1138. alias_name if alias_name else f"missing_pct({str(on)})",
  1139. on=on,
  1140. ignore_nulls=False, # Include nulls for this calculation
  1141. zero_factory=lambda: [0, 0], # Our AggType is a simple list
  1142. )
  1143. def aggregate_block(self, block: Block) -> List[int]:
  1144. column_accessor = BlockColumnAccessor.for_column(block[self._target_col_name])
  1145. total_count = column_accessor.count(ignore_nulls=False)
  1146. null_count = pc.sum(
  1147. pc.is_null(column_accessor._as_arrow_compatible(), nan_is_null=True)
  1148. ).as_py()
  1149. # Return our accumulator
  1150. return [null_count, total_count]
  1151. def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
  1152. # Merge two accumulators by summing their components
  1153. assert len(current_accumulator) == len(new) == 2
  1154. return [
  1155. current_accumulator[0] + new[0], # Sum null counts
  1156. current_accumulator[1] + new[1], # Sum total counts
  1157. ]
  1158. def finalize(self, accumulator: List[int]) -> Optional[float]:
  1159. # Calculate the final percentage
  1160. if accumulator[1] == 0:
  1161. return None
  1162. return (accumulator[0] / accumulator[1]) * 100.0
  1163. @PublicAPI(stability="alpha")
  1164. class ZeroPercentage(AggregateFnV2[List[int], float]):
  1165. """Calculates the percentage of zero values in a numeric column.
  1166. This aggregation computes the percentage of zero values in a numeric dataset column.
  1167. It can optionally ignore null values when calculating the percentage. The result is
  1168. a percentage value between 0.0 and 100.0, where 0.0 means no zero values and 100.0
  1169. means all non-null values are zero.
  1170. Example:
  1171. .. testcode::
  1172. import ray
  1173. from ray.data.aggregate import ZeroPercentage
  1174. # Create a dataset with some zero values
  1175. ds = ray.data.from_items([
  1176. {"value": 0}, {"value": 1}, {"value": 0},
  1177. {"value": 3}, {"value": 0}
  1178. ])
  1179. # Calculate zero value percentage
  1180. result = ds.aggregate(ZeroPercentage(on="value"))
  1181. # result: 60.0 (3 out of 5 values are zero)
  1182. # With null values and ignore_nulls=True (default)
  1183. ds = ray.data.from_items([
  1184. {"value": 0}, {"value": None}, {"value": 0},
  1185. {"value": 3}, {"value": 0}
  1186. ])
  1187. result = ds.aggregate(ZeroPercentage(on="value", ignore_nulls=True))
  1188. # result: 75.0 (3 out of 4 non-null values are zero)
  1189. # Using with groupby
  1190. ds = ray.data.from_items([
  1191. {"group": "A", "value": 0}, {"group": "A", "value": 1},
  1192. {"group": "B", "value": 0}, {"group": "B", "value": 0}
  1193. ])
  1194. result = ds.groupby("group").aggregate(ZeroPercentage(on="value")).take_all()
  1195. # result: [{'group': 'A', 'zero_pct(value)': 50.0},
  1196. # {'group': 'B', 'zero_pct(value)': 100.0}]
  1197. Args:
  1198. on: The name of the column to calculate zero value percentage on.
  1199. Must be a numeric column.
  1200. ignore_nulls: Whether to ignore null values when calculating the percentage.
  1201. If True (default), null values are excluded from both numerator and denominator.
  1202. If False, null values are included in the denominator but not the numerator.
  1203. alias_name: Optional name for the resulting column. If not provided,
  1204. defaults to "zero_pct({column_name})".
  1205. """
  1206. def __init__(
  1207. self,
  1208. on: str,
  1209. ignore_nulls: bool = True,
  1210. alias_name: Optional[str] = None,
  1211. ):
  1212. # Initialize with a list accumulator [zero_count, non_null_count]
  1213. super().__init__(
  1214. alias_name if alias_name else f"zero_pct({str(on)})",
  1215. on=on,
  1216. ignore_nulls=ignore_nulls,
  1217. zero_factory=lambda: [0, 0],
  1218. )
  1219. def aggregate_block(self, block: Block) -> List[int]:
  1220. column_accessor = BlockColumnAccessor.for_column(block[self._target_col_name])
  1221. count = column_accessor.count(ignore_nulls=self._ignore_nulls)
  1222. if count == 0:
  1223. return [0, 0]
  1224. arrow_compatible = column_accessor._as_arrow_compatible()
  1225. # Use PyArrow compute to count zeros
  1226. # First create a boolean mask for zero values
  1227. zero_mask = pc.equal(arrow_compatible, 0)
  1228. # Sum the boolean mask to get count of True values (zeros)
  1229. zero_count = pc.sum(zero_mask).as_py() or 0
  1230. return [zero_count, count]
  1231. def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
  1232. return [
  1233. current_accumulator[0] + new[0], # Sum zero counts
  1234. current_accumulator[1] + new[1], # Sum non-null counts
  1235. ]
  1236. def finalize(self, accumulator: List[int]) -> Optional[float]:
  1237. if accumulator[1] == 0:
  1238. return None
  1239. return (accumulator[0] / accumulator[1]) * 100.0
  1240. @PublicAPI(stability="alpha")
  1241. class ApproximateQuantile(AggregateFnV2):
  1242. def _require_datasketches(self):
  1243. try:
  1244. from datasketches import kll_floats_sketch # type: ignore[import]
  1245. except ImportError as exc:
  1246. raise ImportError(
  1247. "ApproximateQuantile requires the `datasketches` package. "
  1248. "Install it with `pip install datasketches`."
  1249. ) from exc
  1250. return kll_floats_sketch
  1251. def __init__(
  1252. self,
  1253. on: str,
  1254. quantiles: List[float],
  1255. quantile_precision: int = 800,
  1256. alias_name: Optional[str] = None,
  1257. ):
  1258. """
  1259. Computes the approximate quantiles of a column by using a datasketches kll_floats_sketch.
  1260. https://datasketches.apache.org/docs/KLL/KLLSketch.html
  1261. The accuracy of the KLL quantile sketch is a function of the configured quantile precision, which also affects
  1262. the overall size of the sketch.
  1263. The KLL Sketch has absolute error. For example, a specified rank accuracy of 1% at the
  1264. median (rank = 0.50) means that the true quantile (if you could extract it from the set)
  1265. should be between getQuantile(0.49) and getQuantile(0.51). This same 1% error applied at a
  1266. rank of 0.95 means that the true quantile should be between getQuantile(0.94) and getQuantile(0.96).
  1267. In other words, the error is a fixed +/- epsilon for the entire range of ranks.
  1268. Typical single-sided rank error by quantile_precision (use for getQuantile/getRank):
  1269. - quantile_precision=100 → ~2.61%
  1270. - quantile_precision=200 → ~1.33%
  1271. - quantile_precision=400 → ~0.68%
  1272. - quantile_precision=800 → ~0.35%
  1273. See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for details on accuracy and size.
  1274. Null values in the target column are ignored when constructing the sketch.
  1275. Example:
  1276. .. testcode::
  1277. import ray
  1278. from ray.data.aggregate import ApproximateQuantile
  1279. # Create a dataset with some values
  1280. ds = ray.data.from_items(
  1281. [{"value": 20.0}, {"value": 40.0}, {"value": 60.0},
  1282. {"value": 80.0}, {"value": 100.0}]
  1283. )
  1284. result = ds.aggregate(ApproximateQuantile(on="value", quantiles=[0.1, 0.5, 0.9]))
  1285. # Result: {'approx_quantile(value)': [20.0, 60.0, 100.0]}
  1286. Args:
  1287. on: The name of the column to calculate the quantile on. Must be a numeric column.
  1288. quantiles: The list of quantiles to compute. Must be between 0 and 1 inclusive. For example, quantiles=[0.5] computes the median. Null entries in the source column are skipped.
  1289. quantile_precision: Controls the accuracy and memory footprint of the sketch (K in KLL); higher values yield lower error but use more memory. Defaults to 800. See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for details on accuracy and size.
  1290. alias_name: Optional name for the resulting column. If not provided, defaults to "approx_quantile({column_name})".
  1291. """
  1292. self._sketch_cls = self._require_datasketches()
  1293. self._quantiles = quantiles
  1294. self._quantile_precision = quantile_precision
  1295. super().__init__(
  1296. alias_name if alias_name else f"approx_quantile({str(on)})",
  1297. on=on,
  1298. ignore_nulls=True,
  1299. zero_factory=lambda: self.zero(quantile_precision).serialize(),
  1300. )
  def zero(self, quantile_precision: int):
      """Create a fresh, empty sketch.

      Args:
          quantile_precision: Passed to the sketch constructor as ``k``.

      Returns:
          A new instance of the sketch class stored on ``self._sketch_cls``.
      """
      sketch_factory = self._sketch_cls
      return sketch_factory(k=quantile_precision)
  def aggregate_block(self, block: Block) -> bytes:
      """Build a sketch from the non-null values of the target column of ``block``.

      The block is read as an Arrow table, every non-null value of the target
      column is converted to ``float`` and fed into a new sketch created with
      ``self._quantile_precision``.

      Args:
          block: The block to summarise.

      Returns:
          The serialized sketch for this block.
      """
      block_acc = BlockAccessor.for_block(block)
      table = block_acc.to_arrow()
      column = table.column(self.get_target_column())
      sketch = self.zero(self._quantile_precision)
      for value in column:
          # Convert each Arrow scalar to a Python object exactly once; the
          # previous code called ``as_py()`` twice per element.
          py_value = value.as_py()
          # Nulls surface as ``None`` and are ignored.
          if py_value is not None:
              sketch.update(float(py_value))
      return sketch.serialize()
  def combine(self, current_accumulator: bytes, new: bytes) -> bytes:
      """Merge two serialized sketches into one serialized sketch.

      Args:
          current_accumulator: Serialized sketch accumulated so far.
          new: Serialized sketch to fold into the accumulator.

      Returns:
          The serialized result of merging both inputs into an empty sketch
          created with ``self._quantile_precision``.
      """
      merged = self.zero(self._quantile_precision)
      # Merge order matches the argument order: accumulator first, then ``new``.
      for payload in (current_accumulator, new):
          merged.merge(self._sketch_cls.deserialize(payload))
      return merged.serialize()
  def finalize(self, accumulator: bytes) -> List[float]:
      """Turn the final serialized sketch into the requested quantile values.

      Args:
          accumulator: Serialized sketch produced by the aggregation.

      Returns:
          The result of ``get_quantiles`` for ``self._quantiles`` on the
          deserialized sketch.
      """
      sketch = self._sketch_cls.deserialize(accumulator)
      requested = self._quantiles
      return sketch.get_quantiles(requested)
  @PublicAPI(stability="alpha")
  class ApproximateTopK(AggregateFnV2):
      # Approximate top-k aggregation backed by a datasketches
      # ``frequent_strings_sketch``. Values are pickled and hex-encoded so that
      # arbitrary Python objects can be stored in the string-only sketch.

      def _require_datasketches(self):
          """Import and return ``datasketches.frequent_strings_sketch``.

          Raises:
              ImportError: If the ``datasketches`` package is not installed.
          """
          try:
              from datasketches import frequent_strings_sketch
          except ImportError as exc:
              raise ImportError(
                  "ApproximateTopK requires the `datasketches` package. "
                  "Install it with `pip install datasketches`."
              ) from exc
          return frequent_strings_sketch

      def __init__(
          self,
          on: str,
          k: int,
          log_capacity: int = 15,
          alias_name: Optional[str] = None,
          encode_lists: bool = False,
      ):
          """
          Computes the approximate top k items in a column by using a datasketches frequent_strings_sketch.

          https://datasketches.apache.org/docs/Frequency/FrequentItemsOverview.html

          Guarantees:

              - Any item with true frequency > N / (2^log_capacity) is guaranteed to appear in the results
              - Reported counts may have an error of at most ± N / (2^log_capacity).

          If log_capacity is too small for your data:

              - Low-frequency items may be evicted from the sketch, potentially causing the top-k
                results to miss items that should appear in the output.
              - The error bounds increase, reducing the accuracy of the reported counts.

          Example:

              .. testcode::

                  import ray
                  from ray.data.aggregate import ApproximateTopK

                  ds = ray.data.from_items([
                      {"word": "apple"}, {"word": "banana"}, {"word": "apple"},
                      {"word": "cherry"}, {"word": "apple"}
                  ])
                  result = ds.aggregate(ApproximateTopK(on="word", k=2))
                  # Result: {'approx_topk(word)': [{'word': 'apple', 'count': 3}, {'word': 'banana', 'count': 1}]}

          Args:
              on: The name of the column to aggregate.
              k: The number of top items to return. Must not be negative.
              log_capacity: Base 2 logarithm of the maximum size of the internal hash map.
                  Higher values increase accuracy but use more memory. Defaults to 15.
              alias_name: The name of the aggregate. If not provided, defaults to
                  "approx_topk({column_name})".
              encode_lists: If `True`, encode list elements. If `False`, encode
                  whole lists (i.e., the entire list is considered as a single object).
                  `False` by default. Note that this is a top-level flatten (not a recursive
                  flatten) operation.

          Raises:
              ValueError: If ``k`` is a negative integer.
              ImportError: If the ``datasketches`` package is not installed.
          """
          # A negative ``k`` would be used as ``items[:k]`` in ``finalize`` and
          # silently drop entries from the tail instead of selecting a top-k,
          # so reject it up front. Only negative integers are rejected; every
          # other value is handled exactly as before.
          if isinstance(k, int) and k < 0:
              raise ValueError(f"`k` must be non-negative, got {k}.")

          self.k = k
          self._log_capacity = log_capacity
          self._frequent_strings_sketch = self._require_datasketches()
          self._encode_lists = encode_lists
          super().__init__(
              alias_name if alias_name else f"approx_topk({str(on)})",
              on=on,
              ignore_nulls=True,
              zero_factory=lambda: self.zero(log_capacity).serialize(),
          )

      @staticmethod
      def _encode_item(item: Any) -> str:
          """Pickle ``item`` and return the hex string that is stored in the sketch."""
          return pickle.dumps(item).hex()

      @staticmethod
      def _decode_item(encoded: str) -> Any:
          """Invert ``_encode_item``: hex string back to the original Python value."""
          # NOTE: ``pickle.loads`` is only applied to strings produced by
          # ``_encode_item`` inside this aggregation; never feed it external data.
          return pickle.loads(bytes.fromhex(encoded))

      def zero(self, log_capacity: int):
          """Create an empty sketch, passing ``log_capacity`` as ``lg_max_k``."""
          return self._frequent_strings_sketch(lg_max_k=log_capacity)

      def aggregate_block(self, block: Block) -> bytes:
          """Build a sketch from the non-null values of the target column of ``block``.

          When ``encode_lists`` is enabled, list values contribute each of their
          non-null elements individually; otherwise each non-null value is
          counted as a single item.

          Args:
              block: The block to summarise.

          Returns:
              The serialized sketch for this block.
          """
          # Note: The datasketches Python bindings only expose frequent_strings_sketch
          # (not type-specific variants like frequent_ints_sketch). We use pickle
          # serialization as a workaround, which is less performant than native
          # type-specific sketches. Revisit if type-specific bindings are added.
          block_acc = BlockAccessor.for_block(block)
          table = block_acc.to_arrow()
          column = table.column(self.get_target_column())
          sketch = self.zero(self._log_capacity)
          for value in column:
              py_value = value.as_py()
              # Top-level flatten only: nested lists inside a list stay intact.
              if self._encode_lists and isinstance(py_value, list):
                  items = py_value
              else:
                  items = (py_value,)
              for item in items:
                  # Nulls (whole values or list elements) are ignored.
                  if item is not None:
                      sketch.update(self._encode_item(item))
          return sketch.serialize()

      def combine(self, current_accumulator: bytes, new: bytes) -> bytes:
          """Merge two serialized sketches into one serialized sketch.

          Args:
              current_accumulator: Serialized sketch accumulated so far.
              new: Serialized sketch to fold into the accumulator.

          Returns:
              The serialized result of merging both inputs into an empty sketch.
          """
          combined = self.zero(self._log_capacity)
          combined.merge(self._frequent_strings_sketch.deserialize(current_accumulator))
          combined.merge(self._frequent_strings_sketch.deserialize(new))
          return combined.serialize()

      def finalize(self, accumulator: bytes) -> List[Dict[str, Any]]:
          """Convert the final serialized sketch into the top-k result rows.

          Args:
              accumulator: Serialized sketch produced by the aggregation.

          Returns:
              At most ``k`` dictionaries, each holding the decoded item under the
              target column name and its estimated frequency under ``"count"``.
          """
          from datasketches import frequent_items_error_type

          column = self.get_target_column()
          sketch = self._frequent_strings_sketch.deserialize(accumulator)
          frequent_items = sketch.get_frequent_items(
              frequent_items_error_type.NO_FALSE_NEGATIVES
          )
          # NOTE(review): taking the first ``k`` rows assumes the sketch returns
          # rows ordered by descending estimate — confirm against the
          # datasketches documentation.
          return [
              {column: self._decode_item(row[0]), "count": int(row[1])}
              for row in frequent_items[: self.k]
          ]