| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777 |
- import abc
- import enum
- import math
- import pickle
- import re
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Collection,
- Dict,
- Generic,
- List,
- Optional,
- Protocol,
- Set,
- TypeVar,
- Union,
- )
- import numpy as np
- import pyarrow.compute as pc
- from ray.data._internal.util import is_null
- from ray.data.block import (
- Block,
- BlockAccessor,
- BlockColumn,
- BlockColumnAccessor,
- KeyType,
- )
- from ray.util.annotations import Deprecated, PublicAPI
- if TYPE_CHECKING:
- from ray.data.dataset import Schema
class _SupportsRichComparison(Protocol):
    """Structural type for values that implement the four ordering operators.

    Any class defining ``__lt__``, ``__le__``, ``__gt__`` and ``__ge__`` with
    these signatures satisfies the protocol; no explicit inheritance is needed.
    The method bodies are intentionally empty stubs.
    """

    def __lt__(self, other: Any) -> bool:
        """Return whether ``self < other``."""
        ...

    def __le__(self, other: Any) -> bool:
        """Return whether ``self <= other``."""
        ...

    def __gt__(self, other: Any) -> bool:
        """Return whether ``self > other``."""
        ...

    def __ge__(self, other: Any) -> bool:
        """Return whether ``self >= other``."""
        ...
# Splits an aggregation display name such as "sum(col)" into its leading part:
# group 1 captures everything before the first "(" (e.g. "sum"). A name with no
# parenthesised suffix is captured whole.
_AGGREGATION_NAME_PATTERN = re.compile(r"^([^(]+)(?:\(.*\))?$")

# Type of the intermediate state (accumulator) carried through an aggregation.
AccumulatorType = TypeVar("AccumulatorType")

# Type of the final value an aggregation produces.
AggOutputType = TypeVar("AggOutputType")

# Any value type that supports the rich comparison operators.
SupportsRichComparisonType = TypeVar(
    "SupportsRichComparisonType",
    bound=_SupportsRichComparison,
)
@Deprecated(message="AggregateFn is deprecated, please use AggregateFnV2")
@PublicAPI
class AggregateFn:
    """NOTE: THIS IS DEPRECATED, PLEASE USE :class:`AggregateFnV2` INSTEAD

    Bundles the callbacks that describe a custom aggregation in Ray Data.

    An `AggregateFn` is handed to a Dataset's ``.aggregate(...)`` method and tells
    it how rows sharing the same key are folded into an accumulator, how two
    accumulators are merged, and how the final accumulator is turned into the
    output value. Use it for aggregations that the standard built-in options
    (Sum, Min, Max, Mean, etc.) do not cover.

    Args:
        init: Called with the group key; returns the starting accumulator state
            for that group (commonly 0, an empty list, or an empty dictionary).
        merge: Takes two accumulators generated by different workers and returns
            one merged accumulator.
        name: Display name of the aggregator, useful for debugging. Must be a
            ``str``.
        accumulate_row: Folds a single row into the accumulator: receives the
            current accumulator and a row and returns the updated accumulator.
            Mutually exclusive with `accumulate_block`.
        accumulate_block: Folds a whole block of rows into the accumulator at
            once, which allows vectorized operations: receives the current
            accumulator and a block and returns the updated accumulator.
            Mutually exclusive with `accumulate_row`.
        finalize: Converts the final accumulator into the desired output (for
            example, computing a statistic from a list of collected items).
            When omitted, the final accumulator is returned as-is.

    Raises:
        ValueError: If neither or both of `accumulate_row` and
            `accumulate_block` are given.
        TypeError: If `name` is not a ``str``.

    Example:
        .. testcode::

            import ray
            from ray.data.aggregate import AggregateFn

            # A simple aggregator that counts how many rows there are per group
            count_agg = AggregateFn(
                init=lambda k: 0,
                accumulate_row=lambda counter, row: counter + 1,
                merge=lambda c1, c2: c1 + c2,
                name="custom_count"
            )
            ds = ray.data.from_items([{"group": "A"}, {"group": "B"}, {"group": "A"}])
            result = ds.groupby("group").aggregate(count_agg).take_all()
            # result: [{'group': 'A', 'custom_count': 2}, {'group': 'B', 'custom_count': 1}]
    """

    def __init__(
        self,
        init: Callable[[KeyType], AccumulatorType],
        merge: Callable[[AccumulatorType, AccumulatorType], AccumulatorType],
        name: str,
        accumulate_row: Callable[
            [AccumulatorType, Dict[str, Any]], AccumulatorType
        ] = None,
        accumulate_block: Callable[[AccumulatorType, Block], AccumulatorType] = None,
        finalize: Optional[Callable[[AccumulatorType], AggOutputType]] = None,
    ):
        row_fn_given = accumulate_row is not None
        block_fn_given = accumulate_block is not None
        # Equal flags mean "both given" or "neither given"; either is invalid.
        if row_fn_given == block_fn_given:
            raise ValueError(
                "Exactly one of accumulate_row or accumulate_block must be provided."
            )

        if not isinstance(name, str):
            raise TypeError("`name` must be provided.")

        if not block_fn_given:
            # Only a per-row callback was supplied: derive the block-level
            # callback by feeding it every row of the block in turn.
            def accumulate_block(a: AccumulatorType, block: Block) -> AccumulatorType:
                rows = BlockAccessor.for_block(block).iter_rows(
                    public_row_format=False
                )
                for row in rows:
                    a = accumulate_row(a, row)
                return a

        self.name = name
        self.init = init
        self.merge = merge
        self.accumulate_block = accumulate_block
        # Default finalizer hands the accumulator back unchanged.
        self.finalize = finalize if finalize is not None else (lambda a: a)

    def _validate(self, schema: Optional["Schema"]) -> None:
        """Raise an error if this cannot be applied to the given schema.

        This base implementation performs no checks.
        """
        return None
@PublicAPI(stability="alpha")
class AggregateFnV2(AggregateFn, abc.ABC, Generic[AccumulatorType, AggOutputType]):
    """Base class for implementing efficient aggregations over a dataset.

    `AggregateFnV2` instances are passed to a Dataset's ``.aggregate(...)`` method
    to perform distributed aggregations. A custom aggregation is written by
    subclassing `AggregateFnV2` and implementing `aggregate_block` and `combine`;
    `finalize` may additionally be overridden when the final accumulated state
    needs further transformation.

    An aggregation proceeds as follows:

    1. **Initialization**: For each group (if grouping) or for the entire
       dataset, an initial accumulator is created using `zero_factory`.
    2. **Block Aggregation**: `aggregate_block` is applied to each block
       independently, producing a partial aggregation result for that block.
    3. **Combination**: `combine` merges these partial results (or an existing
       accumulated result with a new partial result) into a single accumulator.
    4. **Finalization**: Optionally, `finalize` transforms the final combined
       accumulator into the desired output format.

    Generic Type Parameters:
        This class is parameterized by two type variables:

        - ``AccumulatorType``: The type of the intermediate state (accumulator)
          used during aggregation. This is what `aggregate_block` returns, what
          `combine` takes as inputs and returns, and what `finalize` receives.
          For simple aggregations like `Sum`, this might just be a numeric type.
          For more complex aggregations like `Mean`, this could be a composite
          type like ``List[Union[int, float]]`` representing ``[sum, count]``.
        - ``AggOutputType``: The type of the final result after `finalize` is
          called. This is what gets written to the output dataset. For `Sum`,
          this is the same as the accumulator type (a number). For `Mean`, the
          accumulator is ``[sum, count]`` but the output is a single ``float``
          (the computed mean).

        Examples of type parameterization in built-in aggregations::

            Count(AggregateFnV2[int, int])  # accumulator: int, output: int
            Sum(AggregateFnV2[Union[int, float], ...])  # accumulator: number, output: number
            Mean(AggregateFnV2[List[...], float])  # accumulator: [sum, count], output: float
            Std(AggregateFnV2[List[...], float])  # accumulator: [M2, mean, count], output: float

    Args:
        name: The name of the aggregation. This will be used as the column name
            in the output, e.g., "sum(my_col)". Must be non-empty.
        zero_factory: A callable that returns the initial "zero" value for the
            accumulator. For example, for a sum, this would be `lambda: 0`; for
            finding a minimum, `lambda: float("inf")`, for finding a maximum,
            `lambda: float("-inf")`.
        on: The name of the column to perform the aggregation on. If `None`,
            the aggregation is performed over the entire row (e.g., for
            `Count()`).
        ignore_nulls: Whether to ignore null values during aggregation.
            If `True`, nulls are skipped.
            If `False`, the presence of a null value might result in a null
            output, depending on the aggregation logic.

    Raises:
        ValueError: If `name` is empty.
    """

    def __init__(
        self,
        name: str,
        zero_factory: Callable[[], AccumulatorType],
        *,
        on: Optional[str],
        ignore_nulls: bool,
    ):
        if not name:
            raise ValueError(
                f"Non-empty string has to be provided as name (got {name})"
            )

        self._target_col_name = on
        self._ignore_nulls = ignore_nulls

        # Remember the bare aggregation name (e.g., "sum" from "sum(col)") now so
        # that no string parsing is needed later. If the pattern does not match,
        # fall back to the full name.
        parsed_name = _AGGREGATION_NAME_PATTERN.match(name)
        self._agg_name = name if parsed_name is None else parsed_name.group(1)

        # Wrap the subclass hooks with the null-handling helpers before handing
        # them to the legacy `AggregateFn` constructor.
        null_safe_merge = _null_safe_combine(self.combine, ignore_nulls)
        null_safe_block_agg = _null_safe_aggregate(self.aggregate_block, ignore_nulls)
        null_safe_finalize = _null_safe_finalize(self.finalize)
        null_safe_init = _null_safe_zero_factory(zero_factory, ignore_nulls)

        super().__init__(
            name=name,
            init=null_safe_init,
            merge=null_safe_merge,
            # The incoming accumulator is ignored: each block is aggregated on
            # its own and only the block's partial result is returned.
            accumulate_block=lambda _, block: null_safe_block_agg(block),
            finalize=null_safe_finalize,
        )

    def get_target_column(self) -> Optional[str]:
        """Return the name of the column being aggregated, or `None` if unset."""
        return self._target_col_name

    def get_agg_name(self) -> str:
        """Return the agg name (e.g., 'sum', 'mean', 'count').

        This is the aggregation type extracted from the name during
        initialization. For example, an aggregator named 'sum(col)' yields 'sum'.
        """
        return self._agg_name

    @abc.abstractmethod
    def combine(
        self, current_accumulator: AccumulatorType, new: AccumulatorType
    ) -> AccumulatorType:
        """Merge a new partial aggregation result into the current accumulator.

        This defines how two intermediate aggregation states are merged. For
        example, if `aggregate_block` produces partial sums `s1` and `s2` from
        two different blocks, `combine(s1, s2)` should return `s1 + s2`.

        Args:
            current_accumulator: The current accumulated state (e.g., the result
                of previous `combine` calls or an initial value from
                `zero_factory`).
            new: A new partially aggregated value, typically the output of
                `aggregate_block` from a new block of data, or another
                accumulator from a parallel task.

        Returns:
            The updated accumulator after combining it with the new value.
        """
        ...

    @abc.abstractmethod
    def aggregate_block(self, block: Block) -> AccumulatorType:
        """Aggregate the data within a single block.

        Processes all rows in the given `Block` and returns a partial
        aggregation result for that block. For instance, a sum implementation
        would add up all relevant values within the block.

        Args:
            block: A `Block` of data to be aggregated.

        Returns:
            A partial aggregation result for the input block. Its type
            (``AccumulatorType``) should be consistent with the
            `current_accumulator` and `new` arguments of `combine`, and with the
            `accumulator` argument of `finalize`.
        """
        ...

    def finalize(self, accumulator: AccumulatorType) -> Optional[AggOutputType]:
        """Transform the final accumulated state into the desired output.

        Called once per group after all blocks have been processed and all
        partial results have been combined, giving a chance to post-process the
        accumulated data.

        For many aggregations (e.g., Sum, Count, Min, Max) the accumulated state
        already is the final result, so the default implementation simply
        returns the accumulator unchanged. Others depend on this step: a Mean
        aggregation might accumulate `[sum, count]` and compute `sum / count`
        here.

        Args:
            accumulator: The final accumulated state for a group, after all
                `aggregate_block` and `combine` operations.

        Returns:
            The final result of the aggregation for the group.
        """
        return accumulator

    def _validate(self, schema: Optional["Schema"]) -> None:
        """Validate the target column (when one is set) against `schema`."""
        if not self._target_col_name:
            return

        # Imported lazily, and only when there is a column to validate.
        from ray.data._internal.planner.exchange.sort_task_spec import SortKey

        SortKey(self._target_col_name).validate_schema(schema)
@PublicAPI
class Count(AggregateFnV2[int, int]):
    """Defines count aggregation.

    Without a target column the aggregation counts rows; with one, it counts the
    values of that column.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Count

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Counting all rows:
            result = ds.aggregate(Count())
            # result: {'count()': 100}

            # Counting all rows per group:
            result = ds.groupby("group_key").aggregate(Count(on="id")).take_all()
            # result: [{'group_key': 0, 'count(id)': 34},
            #          {'group_key': 1, 'count(id)': 33},
            #          {'group_key': 2, 'count(id)': 33}]

    Args:
        on: Optional name of the column to count values on. If None, counts rows.
        ignore_nulls: Whether to ignore null values when counting. Only applies
            if `on` is specified. Default is `False`, which means `Count()` on a
            column will count nulls by default. To match pandas default behavior
            of not counting nulls, set `ignore_nulls=True`.
        alias_name: Optional name for the resulting column. Defaults to
            ``count(<on>)``, or ``count()`` when `on` is not given.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = False,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"count({on or ''})"
        super().__init__(
            output_name,
            zero_factory=lambda: 0,
            on=on,
            ignore_nulls=ignore_nulls,
        )

    def aggregate_block(self, block: Block) -> int:
        """Return the row count of `block`, or the value count of the target column."""
        accessor = BlockAccessor.for_block(block)
        column = self._target_col_name

        if column is not None:
            return accessor.count(column, ignore_nulls=self._ignore_nulls)

        # Global count: the number of rows is all that is needed.
        return accessor.num_rows()

    def combine(self, current_accumulator: int, new: int) -> int:
        """Add two partial counts together."""
        merged_count = current_accumulator + new
        return merged_count
@PublicAPI
class AsList(AggregateFnV2[List, List]):
    """Listing aggregation combining all values within the group into a single
    list element.

    Example:

        .. testcode::
            :skipif: True

            # Skip testing b/c this example require proper ordering of the output
            # to be robust and not flaky

            import ray
            from ray.data.aggregate import AsList

            ds = ray.data.range(10)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Listing all elements per group:
            result = ds.groupby("group_key").aggregate(AsList(on="id")).take_all()
            # result: [{'group_key': 0, 'list(id)': [0, 3, 6, 9]},
            #          {'group_key': 1, 'list(id)': [1, 4, 7]},
            #          {'group_key': 2, 'list(id)': [2, 5, 8]}]

    Args:
        on: The name of the column to collect values from. Must be provided.
        alias_name: Optional name for the resulting column. Defaults to
            ``list(<on>)``.
        ignore_nulls: Whether to ignore null values when collecting. If `True`,
            nulls are skipped. If `False` (default), nulls are included in the
            list.
    """

    def __init__(
        self,
        on: str,
        alias_name: Optional[str] = None,
        ignore_nulls: bool = False,
    ):
        output_name = alias_name or f"list({on or ''})"
        super().__init__(
            output_name,
            zero_factory=lambda: [],
            on=on,
            ignore_nulls=ignore_nulls,
        )

    def aggregate_block(self, block: Block) -> AccumulatorType:
        """Return the target column's values in `block` as a Python list."""
        values = BlockColumnAccessor.for_column(block[self.get_target_column()])

        if self._ignore_nulls:
            # Drop nulls first, then re-wrap the filtered column.
            values = BlockColumnAccessor.for_column(values.dropna())

        return values.to_pylist()

    def combine(
        self, current_accumulator: AccumulatorType, new: AccumulatorType
    ) -> AccumulatorType:
        """Concatenate two partial lists, keeping `current_accumulator` first."""
        concatenated = current_accumulator + new
        return concatenated
@PublicAPI
class Sum(AggregateFnV2[Union[int, float], Union[int, float]]):
    """Defines sum aggregation.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Sum

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Summing the "id" column over the whole dataset:
            result = ds.aggregate(Sum(on="id"))
            # result: {'sum(id)': 4950}

    Args:
        on: The name of the numerical column to sum. Must be provided.
        ignore_nulls: Whether to ignore null values during summation. If `True`
            (default), nulls are skipped. If `False`, the sum will be null if
            any value in the group is null.
        alias_name: Optional name for the resulting column. Defaults to
            ``sum(<on>)``.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"sum({str(on)})"
        super().__init__(
            output_name,
            zero_factory=lambda: 0,
            on=on,
            ignore_nulls=ignore_nulls,
        )

    def aggregate_block(self, block: Block) -> Union[int, float]:
        """Return the sum of the target column within `block`."""
        accessor = BlockAccessor.for_block(block)
        return accessor.sum(self._target_col_name, self._ignore_nulls)

    def combine(
        self, current_accumulator: Union[int, float], new: Union[int, float]
    ) -> Union[int, float]:
        """Add two partial sums together."""
        merged_sum = current_accumulator + new
        return merged_sum
@PublicAPI
class Min(AggregateFnV2[SupportsRichComparisonType, SupportsRichComparisonType]):
    """Defines min aggregation.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Min

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Finding the minimum value per group:
            result = ds.groupby("group_key").aggregate(Min(on="id")).take_all()
            # result: [{'group_key': 0, 'min(id)': 0},
            #          {'group_key': 1, 'min(id)': 1},
            #          {'group_key': 2, 'min(id)': 2}]

    Args:
        on: The name of the column to find the minimum value from. Must be
            provided.
        ignore_nulls: Whether to ignore null values. If `True` (default), nulls
            are skipped. If `False`, the minimum will be null if any value in
            the group is null (for most data types, or follow type-specific
            comparison rules with nulls).
        alias_name: Optional name for the resulting column. Defaults to
            ``min(<on>)``.
        zero_factory: A callable that returns the initial "zero" value for the
            accumulator. For example, for a float column, this would be
            `lambda: float("+inf")`. Default is `lambda: float("+inf")`.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
        zero_factory: Callable[[], SupportsRichComparisonType] = lambda: float("+inf"),
    ):
        output_name = alias_name or f"min({str(on)})"
        super().__init__(
            output_name,
            zero_factory=zero_factory,
            on=on,
            ignore_nulls=ignore_nulls,
        )

    def aggregate_block(self, block: Block) -> SupportsRichComparisonType:
        """Return the minimum of the target column within `block`."""
        accessor = BlockAccessor.for_block(block)
        return accessor.min(self._target_col_name, self._ignore_nulls)

    def combine(
        self,
        current_accumulator: SupportsRichComparisonType,
        new: SupportsRichComparisonType,
    ) -> SupportsRichComparisonType:
        """Return the smaller of two partial minima."""
        smaller = min(current_accumulator, new)
        return smaller
@PublicAPI
class Max(AggregateFnV2[SupportsRichComparisonType, SupportsRichComparisonType]):
    """Defines max aggregation.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Max

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Finding the maximum value per group:
            result = ds.groupby("group_key").aggregate(Max(on="id")).take_all()
            # result: [{'group_key': 0, 'max(id)': ...},
            #          {'group_key': 1, 'max(id)': ...},
            #          {'group_key': 2, 'max(id)': ...}]

    Args:
        on: The name of the column to find the maximum value from. Must be
            provided.
        ignore_nulls: Whether to ignore null values. If `True` (default), nulls
            are skipped. If `False`, the maximum will be null if any value in
            the group is null (for most data types, or follow type-specific
            comparison rules with nulls).
        alias_name: Optional name for the resulting column. Defaults to
            ``max(<on>)``.
        zero_factory: A callable that returns the initial "zero" value for the
            accumulator. For example, for a float column, this would be
            `lambda: float("-inf")`. Default is `lambda: float("-inf")`.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
        zero_factory: Callable[[], SupportsRichComparisonType] = lambda: float("-inf"),
    ):
        output_name = alias_name or f"max({str(on)})"
        super().__init__(
            output_name,
            zero_factory=zero_factory,
            on=on,
            ignore_nulls=ignore_nulls,
        )

    def aggregate_block(self, block: Block) -> SupportsRichComparisonType:
        """Return the maximum of the target column within `block`."""
        accessor = BlockAccessor.for_block(block)
        return accessor.max(self._target_col_name, self._ignore_nulls)

    def combine(
        self,
        current_accumulator: SupportsRichComparisonType,
        new: SupportsRichComparisonType,
    ) -> SupportsRichComparisonType:
        """Return the larger of two partial maxima."""
        larger = max(current_accumulator, new)
        return larger
@PublicAPI
class Mean(AggregateFnV2[List[Union[int, float]], float]):
    """Defines mean (average) aggregation.

    The accumulator is the pair ``[sum, count]``; the mean is computed from it
    in `finalize`.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Mean

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Calculating the mean value per group:
            result = ds.groupby("group_key").aggregate(Mean(on="id")).take_all()
            # result: [{'group_key': 0, 'mean(id)': ...},
            #          {'group_key': 1, 'mean(id)': ...},
            #          {'group_key': 2, 'mean(id)': ...}]

    Args:
        on: The name of the numerical column to calculate the mean on. Must be
            provided.
        ignore_nulls: Whether to ignore null values. If `True` (default), nulls
            are skipped. If `False`, the mean will be null if any value in the
            group is null.
        alias_name: Optional name for the resulting column. Defaults to
            ``mean(<on>)``.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"mean({str(on)})"
        super().__init__(
            output_name,
            # Accumulator layout: [current_sum, current_count].
            # NOTE: A brand-new list is built on every call, since some internal
            #       mechanisms might modify accumulators in-place.
            zero_factory=lambda: [0, 0],
            on=on,
            ignore_nulls=ignore_nulls,
        )

    def aggregate_block(self, block: Block) -> Optional[List[Union[int, float]]]:
        """Return ``[sum, count]`` for the target column within `block`.

        Returns `None` when the count is zero or `None`, and returns the sum
        itself when that sum is null.
        """
        accessor = BlockAccessor.for_block(block)

        num_values = accessor.count(self._target_col_name, self._ignore_nulls)
        if num_values == 0 or num_values is None:
            # Empty or all null.
            return None

        total = accessor.sum(self._target_col_name, self._ignore_nulls)
        # With ignore_nulls=False and a column containing 'null', the null sum is
        # passed through as is (to prevent unnecessary type conversions, when,
        # for ex, using Pandas and returning None).
        return total if is_null(total) else [total, num_values]

    def combine(
        self, current_accumulator: List[Union[int, float]], new: List[Union[int, float]]
    ) -> List[Union[int, float]]:
        """Add two ``[sum, count]`` pairs element-wise."""
        merged_sum = current_accumulator[0] + new[0]
        merged_count = current_accumulator[1] + new[1]
        return [merged_sum, merged_count]

    def finalize(self, accumulator: List[Union[int, float]]) -> Optional[float]:
        """Return ``total_sum / total_count``, or NaN when the count is zero."""
        # The final accumulator for a group is [total_sum, total_count].
        total_count = accumulator[1]
        if total_count == 0:
            # With no values (e.g., group was empty or all nulls ignored) the
            # mean is undefined. Return NaN
            return np.nan

        return accumulator[0] / total_count
@PublicAPI
class Std(AggregateFnV2[List[Union[int, float]], float]):
    """Defines standard deviation aggregation.

    Uses Welford's online algorithm for numerical stability. This method computes
    the standard deviation in a single pass. Results may differ slightly from
    libraries like NumPy or Pandas that use a two-pass algorithm but are generally
    more accurate.

    See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Std

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Calculating the standard deviation per group:
            result = ds.groupby("group_key").aggregate(Std(on="id")).take_all()
            # result: [{'group_key': 0, 'std(id)': ...},
            #          {'group_key': 1, 'std(id)': ...},
            #          {'group_key': 2, 'std(id)': ...}]

    Args:
        on: The name of the column to calculate standard deviation on.
        ddof: Delta Degrees of Freedom. The divisor used in calculations is
            `N - ddof`, where `N` is the number of elements. Default is 1.
        ignore_nulls: Whether to ignore null values. Default is True.
        alias_name: Optional name for the resulting column. Defaults to
            ``std(<on>)``.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ddof: int = 1,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"std({str(on)})"
        super().__init__(
            output_name,
            # Accumulator layout: [M2, mean, count]
            #   M2:    sum of squares of differences from the current mean
            #   mean:  current mean
            #   count: current count of non-null elements
            # A brand-new list is built on every call, as it might be modified
            # in-place by some aggregations.
            zero_factory=lambda: [0, 0, 0],
            on=on,
            ignore_nulls=ignore_nulls,
        )
        self._ddof = ddof

    def aggregate_block(self, block: Block) -> List[Union[int, float]]:
        """Return ``[M2, mean, count]`` for the target column within `block`.

        Returns `None` when the count is zero or `None`, and returns the sum
        itself when that sum is null.
        """
        accessor = BlockAccessor.for_block(block)

        num_values = accessor.count(
            self._target_col_name, ignore_nulls=self._ignore_nulls
        )
        if num_values == 0 or num_values is None:
            # Empty or all null.
            return None

        total = accessor.sum(self._target_col_name, self._ignore_nulls)
        if is_null(total):
            # A null sum (e.g., ignore_nulls=False and a null was encountered)
            # is passed through as is to prevent type conversions.
            return total

        block_mean = total / num_values
        block_m2 = accessor.sum_of_squared_diffs_from_mean(
            self._target_col_name, self._ignore_nulls, block_mean
        )
        return [block_m2, block_mean, num_values]

    def combine(
        self, current_accumulator: List[float], new: List[float]
    ) -> List[float]:
        """Merge two ``[M2, mean, count]`` accumulators with the parallel algorithm.

        See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        """
        m2_left, mean_left, n_left = current_accumulator
        m2_right, mean_right, n_right = new

        mean_gap = mean_right - mean_left
        n_total = n_left + n_right

        # NOTE: This mean calculation is used since it's more numerically stable
        #       than mean_left + mean_gap * n_right / n_total, which actually
        #       deviates from Pandas in the ~15th decimal place and causes our
        #       exact comparison tests to fail.
        merged_mean = (mean_left * n_left + mean_right * n_right) / n_total

        # Update the sum of squared differences.
        merged_m2 = m2_left + m2_right + (mean_gap**2) * n_left * n_right / n_total

        return [merged_m2, merged_mean, n_total]

    def finalize(self, accumulator: List[float]) -> Optional[float]:
        """Return ``sqrt(M2 / (count - ddof))``, or NaN if ``count - ddof <= 0``."""
        # Final accumulator: [M2, mean, count]; the mean is not needed here.
        m2, _mean, n_total = accumulator

        # Denominator for the variance calculation.
        degrees_of_freedom = n_total - self._ddof
        if degrees_of_freedom <= 0:
            # A non-positive denominator leaves variance/std undefined (or
            # zero). Return NaN, consistent with pandas/numpy.
            return np.nan

        # Standard deviation is the square root of the variance.
        return math.sqrt(m2 / degrees_of_freedom)
@PublicAPI
class AbsMax(AggregateFnV2[SupportsRichComparisonType, SupportsRichComparisonType]):
    """Aggregation returning the largest absolute value found in a column.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import AbsMax

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Calculating the absolute maximum value per group:
            result = ds.groupby("group_key").aggregate(AbsMax(on="id")).take_all()
            # result: [{'group_key': 0, 'abs_max(id)': ...},
            #          {'group_key': 1, 'abs_max(id)': ...},
            #          {'group_key': 2, 'abs_max(id)': ...}]

    Args:
        on: Name of the column to compute the absolute maximum of. Required;
            anything other than a string is rejected with a ``ValueError``.
        ignore_nulls: Whether null values are skipped. Defaults to True.
        alias_name: Optional name of the output column. Defaults to
            ``abs_max({on})``.
        zero_factory: Callable producing the initial "zero" accumulator value.
            For a float column, for example, this would be `lambda: 0`.
            Defaults to `lambda: 0`.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
        zero_factory: Callable[[], SupportsRichComparisonType] = lambda: 0,
    ):
        # `None` is not a `str`, so a single isinstance check covers both the
        # missing and the wrongly-typed column name.
        if not isinstance(on, str):
            raise ValueError(f"Column to aggregate on has to be provided (got {on})")

        output_name = alias_name or f"abs_max({str(on)})"
        super().__init__(
            output_name,
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=zero_factory,
        )

    def aggregate_block(self, block: Block) -> Optional[SupportsRichComparisonType]:
        """Return the block's largest absolute value, or None if an extreme is null."""
        accessor = BlockAccessor.for_block(block)
        column = self._target_col_name
        skip_nulls = self._ignore_nulls

        highest = accessor.max(column, skip_nulls)
        lowest = accessor.min(column, skip_nulls)
        if is_null(highest) or is_null(lowest):
            return None

        # The value furthest from zero is either the maximum or the minimum.
        return max(abs(highest), abs(lowest))

    def combine(
        self,
        current_accumulator: SupportsRichComparisonType,
        new: SupportsRichComparisonType,
    ) -> SupportsRichComparisonType:
        """Keep the larger of two partial absolute maxima."""
        larger = max(current_accumulator, new)
        return larger
@PublicAPI
class Quantile(AggregateFnV2[List[Any], List[Any]]):
    """Defines Quantile aggregation.

    All values of the target column are collected into a list; the quantile is
    computed on the sorted list at finalization, interpolating linearly
    between the two neighbouring elements when the requested rank falls
    between them. Interpolated results are rounded to 5 decimal places.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Quantile

            ds = ray.data.range(100)
            # Schema: {'id': int64}
            ds = ds.add_column("group_key", lambda x: x % 3)
            # Schema: {'id': int64, 'group_key': int64}

            # Calculating the 50th percentile (median) per group:
            result = ds.groupby("group_key").aggregate(Quantile(q=0.5, on="id")).take_all()
            # result: [{'group_key': 0, 'quantile(id)': ...},
            #          {'group_key': 1, 'quantile(id)': ...},
            #          {'group_key': 2, 'quantile(id)': ...}]

    Args:
        on: The name of the column to calculate the quantile on. Must be provided.
        q: The quantile to compute, which must be between 0 and 1 inclusive.
            For example, q=0.5 computes the median.
        ignore_nulls: Whether to ignore null values. Default is True.
        alias_name: Optional name for the resulting column.

    Raises:
        ValueError: If ``q`` is not within ``[0, 1]``.
    """

    def __init__(
        self,
        on: Optional[str] = None,
        q: float = 0.5,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        # Reject out-of-range quantiles up front: a negative `q` would
        # otherwise silently index from the end of the sorted values and
        # `q > 1` would only fail later with an IndexError in `finalize`.
        if not 0 <= q <= 1:
            raise ValueError(f"`q` must be between 0 and 1 inclusive (got {q})")

        self._q = q

        super().__init__(
            alias_name if alias_name else f"quantile({str(on)})",
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=list,
        )

    @staticmethod
    def _is_value(candidate: Any) -> bool:
        """Tell whether a bare (non-list) accumulator carries an actual value."""
        return candidate is not None and candidate != ""

    def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
        """Concatenate two partial value collections.

        Either side may also be a bare value instead of a list; bare values
        that are ``None`` or ``""`` are dropped. When one side is a list it is
        extended in place and returned.
        """
        current_is_list = isinstance(current_accumulator, list)
        new_is_list = isinstance(new, list)

        if current_is_list and new_is_list:
            current_accumulator.extend(new)
            return current_accumulator

        if current_is_list:
            if self._is_value(new):
                current_accumulator.append(new)
            return current_accumulator

        if new_is_list:
            if self._is_value(current_accumulator):
                new.append(current_accumulator)
            return new

        # Neither side is a list: start a fresh one from whatever is present.
        return [v for v in (current_accumulator, new) if self._is_value(v)]

    def aggregate_block(self, block: Block) -> List[Any]:
        """Collect every value (nulls included) of the target column in the block."""
        block_acc = BlockAccessor.for_block(block)
        column = self._target_col_name
        return [
            row.get(column) for row in block_acc.iter_rows(public_row_format=False)
        ]

    def finalize(self, accumulator: List[Any]) -> Optional[Any]:
        """Compute the quantile from the collected values.

        Returns:
            The quantile value; ``None`` when no values remain; or, when nulls
            are not ignored and at least one is present, the first null
            encountered.
        """
        if self._ignore_nulls:
            values = [v for v in accumulator if not is_null(v)]
        else:
            for v in accumulator:
                if is_null(v):
                    # If nulls are present and not ignored, the quantile is
                    # undefined. Return the first null encountered to preserve
                    # column type.
                    return v
            values = accumulator

        if not values:
            # Nothing left (e.g., all values were null and ignored, or there
            # were no values at all): the quantile is undefined.
            return None

        ordered = sorted(values)

        # Fractional rank of the requested quantile within the sorted values.
        rank = (len(ordered) - 1) * self._q
        lower = math.floor(rank)
        upper = math.ceil(rank)

        if lower == upper:
            # The rank lands exactly on an element.
            return ordered[lower]

        # Interpolate between the elements at floor and ceil indices.
        below = ordered[lower] * (upper - rank)
        above = ordered[upper] * (rank - lower)
        return round(below + above, 5)
@PublicAPI
class Unique(AggregateFnV2[Set[Any], List[Any]]):
    """Defines unique aggregation.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import Unique

            ds = ray.data.range(100)
            ds = ds.add_column("group_key", lambda x: x % 3)

            # Calculating the unique values per group:
            result = ds.groupby("group_key").aggregate(Unique(on="id")).take_all()
            # result: [{'group_key': 0, 'unique(id)': ...},
            #          {'group_key': 1, 'unique(id)': ...},
            #          {'group_key': 2, 'unique(id)': ...}]

    Args:
        on: The name of the column from which to collect unique values.
        ignore_nulls: Whether to ignore null values when collecting unique items.
            Default is False (nulls are kept).
        alias_name: Optional name for the resulting column.
        encode_lists: Controls how list-typed columns are handled. `True`
            flattens the lists and collects unique elements (a top-level, not
            a recursive, flatten); a `ListEncodingMode` member selects that
            mode explicitly; `False` or `None` (the default) treats each whole
            list as a single object.
    """

    class ListEncodingMode(str, enum.Enum):
        """Controls how to encode individual elements inside the list column:

        - FLATTEN: column of element lists is flattened into a single list.
        - HASH: each list element is hashed, a list of unique hashes is returned.

        When no mode is configured, no encoding is applied: elements (lists)
        are stored as is and unique ones are returned.
        """

        FLATTEN = "FLATTEN"
        HASH = "HASH"

    def __init__(
        self,
        on: Optional[str] = None,
        ignore_nulls: bool = False,
        alias_name: Optional[str] = None,
        encode_lists: Union[bool, ListEncodingMode, None] = None,
    ):
        super().__init__(
            alias_name if alias_name else f"unique({str(on)})",
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=set,
        )

        # Normalize `encode_lists` into an optional encoding mode: an explicit
        # mode is used as is, `True` means FLATTEN, anything else disables
        # list encoding.
        if isinstance(encode_lists, Unique.ListEncodingMode):
            self._list_encoding_mode = encode_lists
        elif isinstance(encode_lists, bool) and encode_lists:
            self._list_encoding_mode = Unique.ListEncodingMode.FLATTEN
        else:
            self._list_encoding_mode = None

    def combine(self, current_accumulator: Set[Any], new: Set[Any]) -> Set[Any]:
        """Union two partial results, coercing each to a set first."""
        return self._to_set(current_accumulator) | self._to_set(new)

    def _compute_unique(self, block: Block) -> BlockColumn:
        """Return the unique values of the target column of ``block``.

        List columns are first flattened or hashed according to the configured
        encoding mode; nulls are dropped when ``ignore_nulls`` is set.

        Raises:
            ValueError: If the configured list encoding mode is not supported.
        """
        column = block[self._target_col_name]
        column_accessor = BlockColumnAccessor.for_column(column)

        if (
            column_accessor.is_composed_of_lists()
            and self._list_encoding_mode is not None
        ):
            if self._list_encoding_mode == Unique.ListEncodingMode.FLATTEN:
                column_accessor = BlockColumnAccessor.for_column(
                    column_accessor.flatten()
                )
            elif self._list_encoding_mode == Unique.ListEncodingMode.HASH:
                column_accessor = BlockColumnAccessor.for_column(column_accessor.hash())
            else:
                raise ValueError(
                    f"list encoding mode not supported: {self._list_encoding_mode}"
                )

        if self._ignore_nulls:
            column_accessor = BlockColumnAccessor.for_column(column_accessor.dropna())

        return column_accessor.unique()

    def aggregate_block(self, block: Block) -> List[Any]:
        """Return the block's unique values as a Python list."""
        column = self._compute_unique(block)
        return BlockColumnAccessor.for_column(column).to_pylist()

    @staticmethod
    def _to_set(x):
        """Coerce a partial result (set, list or bare value) into a set."""
        if isinstance(x, set):
            return Unique._normalize_nans(x)
        elif isinstance(x, list):
            # Lists are not hashable, so list elements are turned into tuples
            # (necessary because pyarrow converts all tuples to list
            # internally). Every element is checked rather than only the first
            # one: with nulls retained, the first element may be `None` while
            # later ones are lists.
            hashable = (tuple(v) if isinstance(v, list) else v for v in x)
            return Unique._normalize_nans(hashable)
        else:
            return {x}

    @staticmethod
    def _normalize_nans(x: Collection) -> Set:
        """Build a set from ``x`` with every float NaN replaced by ``np.nan``."""
        # NOTE: Pandas when converting to Python objects instantiates
        #       new `float('nan')` objects which are incomparable b/w each
        #       other. Here we canonicalize any nan instances replacing them
        #       w/ `np.nan`
        return {v if not (isinstance(v, float) and np.isnan(v)) else np.nan for v in x}
@PublicAPI
class CountDistinct(Unique):
    """Aggregation counting the distinct values of a column.

    The distinct values are gathered by the parent ``Unique`` aggregation and
    their number is reported at finalization. This is similar to SQL's
    COUNT(DISTINCT column_name) operation.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import CountDistinct

            # Create a dataset with repeated values
            ds = ray.data.from_items([
                {"category": "A"}, {"category": "B"}, {"category": "A"},
                {"category": "C"}, {"category": "A"}, {"category": "B"}
            ])

            # Count distinct categories
            result = ds.aggregate(CountDistinct(on="category"))
            # result: {'count_distinct(category)': 3}

            # Using with groupby
            ds = ray.data.from_items([
                {"group": "X", "category": "A"}, {"group": "X", "category": "B"},
                {"group": "Y", "category": "A"}, {"group": "Y", "category": "A"}
            ])
            result = ds.groupby("group").aggregate(CountDistinct(on="category")).take_all()
            # result: [{'group': 'X', 'count_distinct(category)': 2},
            #          {'group': 'Y', 'count_distinct(category)': 1}]

    Args:
        on: Name of the column whose distinct values are counted.
        ignore_nulls: Whether null values are left out of the count.
            Defaults to True (nulls are excluded from the count).
        alias_name: Optional name of the output column. Defaults to
            "count_distinct({on})".
    """

    def __init__(
        self,
        on: str,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"count_distinct({str(on)})"
        super().__init__(
            on=on,
            ignore_nulls=ignore_nulls,
            alias_name=output_name,
        )

    def finalize(self, accumulator: Set[Any]) -> int:
        """Report how many distinct values were accumulated."""
        distinct_total = len(accumulator)
        return distinct_total
@PublicAPI
class ValueCounter(AggregateFnV2):
    """Aggregation tallying how often each value occurs in a column.

    Comparable to pandas' `value_counts()`. The result is a dictionary holding
    two parallel lists: "values" with the unique values found in the column
    and "counts" with the number of occurrences of each of them.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import ValueCounter

            # Create a dataset with repeated values
            ds = ray.data.from_items([
                {"category": "A"}, {"category": "B"}, {"category": "A"},
                {"category": "C"}, {"category": "A"}, {"category": "B"}
            ])

            # Count occurrences of each category
            result = ds.aggregate(ValueCounter(on="category"))
            # result: {'value_counter(category)': {'values': ['A', 'B', 'C'], 'counts': [3, 2, 1]}}

            # Using with groupby
            ds = ray.data.from_items([
                {"group": "X", "category": "A"}, {"group": "X", "category": "B"},
                {"group": "Y", "category": "A"}, {"group": "Y", "category": "A"}
            ])
            result = ds.groupby("group").aggregate(ValueCounter(on="category")).take_all()
            # result: [{'group': 'X', 'value_counter(category)': {'values': ['A', 'B'], 'counts': [1, 1]}},
            #          {'group': 'Y', 'value_counter(category)': {'values': ['A'], 'counts': [2]}}]

    Args:
        on: Name of the column whose values are counted. Must be provided.
        alias_name: Optional name of the output column. Defaults to
            "value_counter({column_name})".
    """

    def __init__(
        self,
        on: str,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"value_counter({str(on)})"
        super().__init__(
            output_name,
            on=on,
            ignore_nulls=True,
            zero_factory=lambda: {"values": [], "counts": []},
        )

    def aggregate_block(self, block: Block) -> Dict[str, List]:
        """Delegate the per-block tally to the column accessor's ``value_counts``."""
        target_column = block[self._target_col_name]
        return BlockColumnAccessor.for_column(target_column).value_counts()

    def combine(
        self,
        current_accumulator: Dict[str, List],
        new_accumulator: Dict[str, List],
    ) -> Dict[str, List]:
        """Fold ``new_accumulator`` into ``current_accumulator`` in place and return it."""
        known_values = current_accumulator["values"]
        known_counts = current_accumulator["counts"]

        # One value -> position lookup table, built up front so that each
        # incoming value is resolved without scanning the list.
        position_of = {value: pos for pos, value in enumerate(known_values)}

        incoming = zip(new_accumulator["values"], new_accumulator["counts"])
        for value, occurrences in incoming:
            if value not in position_of:
                # First sighting: register the value at the end of both lists.
                position_of[value] = len(known_values)
                known_values.append(value)
                known_counts.append(occurrences)
                continue

            known_counts[position_of[value]] += occurrences

        return current_accumulator
def _null_safe_zero_factory(zero_factory, ignore_nulls: bool):
    """NOTE: PLEASE READ CAREFULLY BEFORE CHANGING

    The null-safe zero factory is what lets the aggregation protocol behave
    as a proper monoid without wrapping accumulators in extra containers.

    The difficulty is telling an *empty* accumulator apart from an accumulator
    that actually holds a null value:

        - An empty accumulator can be overridden with any value
        - An accumulator holding null can't be overridden if ignore_nulls=False

    The two settings of ignore_nulls are asymmetric, which can be exploited:

        - With ignore_nulls=False, any null in the sequence makes the
          aggregation undefined, so it is expected to return null
        - With ignore_nulls=True, the aggregation returns null if and only if
          the sequence does NOT contain any non-null value

    The zero factory is therefore adapted as follows:

        - If ignore_nulls=True, it returns null, which encodes the empty
          accumulator
        - If ignore_nulls=False, it can't return null (the aggregation would
          incorrectly prioritize it), so it returns the true zero value of the
          aggregation instead (ie 0 for count/sum, -inf for max, etc).

    Args:
        zero_factory: Zero-argument callable producing the aggregation's true
            zero value.
        ignore_nulls: Whether the aggregation ignores nulls.

    Returns:
        A one-argument callable (the argument is ignored) producing the
        initial accumulator.
    """

    def _safe_zero_factory(_):
        # Null stands for "empty" when nulls are ignored; otherwise the real
        # zero value is produced afresh on every call.
        return None if ignore_nulls else zero_factory()

    return _safe_zero_factory
def _null_safe_aggregate(
    aggregate: Callable[[Block], AccumulatorType],
    ignore_nulls: bool,
) -> Callable[[Block], Optional[AccumulatorType]]:
    """Wrap a block-level ``aggregate`` so that null results are normalized.

    Args:
        aggregate: Callable reducing a block to a partial accumulator.
        ignore_nulls: Whether the aggregation ignores nulls.

    Returns:
        A callable that returns ``None`` in place of any null result when
        ``ignore_nulls`` is set, and the untouched result otherwise.
    """

    def _safe_aggregate(block: Block) -> Optional[AccumulatorType]:
        outcome = aggregate(block)
        # NOTE: With `ignore_nulls=True` a null outcome can only mean that the
        #       block does NOT contain any non-null elements, so it is
        #       canonicalized to `None`.
        return None if (is_null(outcome) and ignore_nulls) else outcome

    return _safe_aggregate
def _null_safe_finalize(
    finalize: Callable[[AccumulatorType], AccumulatorType],
) -> Callable[[Optional[AccumulatorType]], AccumulatorType]:
    """Wrap ``finalize`` so that it is only invoked on non-null accumulators.

    Args:
        finalize: Callable turning a final accumulator into the result.

    Returns:
        A callable that hands a null accumulator back as is and applies
        ``finalize`` to anything else.
    """

    def _safe_finalize(acc: Optional[AccumulatorType]) -> AccumulatorType:
        if is_null(acc):
            # Nothing to finalize: pass the null through unchanged.
            return acc
        return finalize(acc)

    return _safe_finalize
def _null_safe_combine(
    combine: Callable[[AccumulatorType, AccumulatorType], AccumulatorType],
    ignore_nulls: bool,
) -> Callable[
    [Optional[AccumulatorType], Optional[AccumulatorType]], Optional[AccumulatorType]
]:
    """Wrap ``combine`` so that it tolerates null accumulators.

    The null-safe combination has to be an associative operation with an
    identity element (zero); in other words, it has to implement a monoid.
    In the presence of null values this is achieved as follows:

        - Case of ignore_nulls=True:
            - If the current accumulator is null (ie empty), return the new
              accumulator
            - If the new accumulator is null (ie empty), return the current one
            - Otherwise combine (current and new)

        - Case of ignore_nulls=False:
            - If the new accumulator is null (ie has null in the sequence,
              b/c we're NOT ignoring nulls), return it
            - If the current accumulator is null (ie had null in the prior
              sequence, b/c we're NOT ignoring nulls), return it
            - Otherwise combine (current and new)
    """

    if ignore_nulls:

        def _safe_combine(
            cur: Optional[AccumulatorType], new: Optional[AccumulatorType]
        ) -> Optional[AccumulatorType]:
            # Null means "empty" here: whichever side is empty yields to the other.
            if is_null(cur):
                return new
            if is_null(new):
                return cur
            return combine(cur, new)

        return _safe_combine

    def _safe_combine(
        cur: Optional[AccumulatorType], new: Optional[AccumulatorType]
    ) -> Optional[AccumulatorType]:
        # Null means "a null was seen" here: it takes over the whole result.
        if is_null(new):
            return new
        if is_null(cur):
            return cur
        return combine(cur, new)

    return _safe_combine
@PublicAPI(stability="alpha")
class MissingValuePercentage(AggregateFnV2[List[int], float]):
    """Calculates the percentage of null values in a column.

    This aggregation computes the percentage of null (missing) values in a dataset column.
    It treats both None values and NaN values as null. The result is a percentage value
    between 0.0 and 100.0, where 0.0 means no missing values and 100.0 means all values
    are missing.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import MissingValuePercentage

            # Create a dataset with some missing values
            ds = ray.data.from_items([
                {"value": 1}, {"value": None}, {"value": 3},
                {"value": None}, {"value": 5}
            ])

            # Calculate missing value percentage
            result = ds.aggregate(MissingValuePercentage(on="value"))
            # result: {'missing_pct(value)': 40.0} (2 out of 5 values are missing)

            # Using with groupby
            ds = ray.data.from_items([
                {"group": "A", "value": 1}, {"group": "A", "value": None},
                {"group": "B", "value": 3}, {"group": "B", "value": None}
            ])
            result = ds.groupby("group").aggregate(MissingValuePercentage(on="value")).take_all()
            # result: [{'group': 'A', 'missing_pct(value)': 50.0},
            #          {'group': 'B', 'missing_pct(value)': 50.0}]

    Args:
        on: The name of the column to calculate missing value percentage on.
        alias_name: Optional name for the resulting column. If not provided,
            defaults to "missing_pct({column_name})".
    """

    def __init__(
        self,
        on: str,
        alias_name: Optional[str] = None,
    ):
        # Initialize with a list accumulator [null_count, total_count]
        super().__init__(
            alias_name if alias_name else f"missing_pct({str(on)})",
            on=on,
            ignore_nulls=False,  # Include nulls for this calculation
            zero_factory=lambda: [0, 0],  # Our AggType is a simple list
        )

    def aggregate_block(self, block: Block) -> List[int]:
        """Return ``[null_count, total_count]`` for the target column of ``block``."""
        column_accessor = BlockColumnAccessor.for_column(block[self._target_col_name])

        total_count = column_accessor.count(ignore_nulls=False)

        null_mask = pc.is_null(
            column_accessor._as_arrow_compatible(), nan_is_null=True
        )
        # Summing an empty mask yields a null scalar (`None` after `as_py()`);
        # fall back to 0 so the accumulator always holds integers and can be
        # added up in `combine`.
        null_count = pc.sum(null_mask).as_py() or 0

        # Return our accumulator
        return [null_count, total_count]

    def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
        """Merge two accumulators by summing their components."""
        assert len(current_accumulator) == len(new) == 2
        return [
            current_accumulator[0] + new[0],  # Sum null counts
            current_accumulator[1] + new[1],  # Sum total counts
        ]

    def finalize(self, accumulator: List[int]) -> Optional[float]:
        """Return the missing-value percentage, or None when no rows were counted."""
        if accumulator[1] == 0:
            return None
        return (accumulator[0] / accumulator[1]) * 100.0
@PublicAPI(stability="alpha")
class ZeroPercentage(AggregateFnV2[List[int], float]):
    """Aggregation reporting which share of a numeric column is zero.

    The result is a percentage between 0.0 and 100.0, where 0.0 means that no
    value is zero and 100.0 means that all non-null values are zero. Null
    values can optionally be left out of the calculation.

    Example:

        .. testcode::

            import ray
            from ray.data.aggregate import ZeroPercentage

            # Create a dataset with some zero values
            ds = ray.data.from_items([
                {"value": 0}, {"value": 1}, {"value": 0},
                {"value": 3}, {"value": 0}
            ])

            # Calculate zero value percentage
            result = ds.aggregate(ZeroPercentage(on="value"))
            # result: 60.0 (3 out of 5 values are zero)

            # With null values and ignore_nulls=True (default)
            ds = ray.data.from_items([
                {"value": 0}, {"value": None}, {"value": 0},
                {"value": 3}, {"value": 0}
            ])
            result = ds.aggregate(ZeroPercentage(on="value", ignore_nulls=True))
            # result: 75.0 (3 out of 4 non-null values are zero)

            # Using with groupby
            ds = ray.data.from_items([
                {"group": "A", "value": 0}, {"group": "A", "value": 1},
                {"group": "B", "value": 0}, {"group": "B", "value": 0}
            ])
            result = ds.groupby("group").aggregate(ZeroPercentage(on="value")).take_all()
            # result: [{'group': 'A', 'zero_pct(value)': 50.0},
            #          {'group': 'B', 'zero_pct(value)': 100.0}]

    Args:
        on: Name of the column to compute the zero percentage of.
            Must be a numeric column.
        ignore_nulls: Whether null values are left out of the calculation.
            If True (default), null values are excluded from both numerator and
            denominator. If False, null values are included in the denominator
            but not the numerator.
        alias_name: Optional name of the output column. Defaults to
            "zero_pct({column_name})".
    """

    def __init__(
        self,
        on: str,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        output_name = alias_name or f"zero_pct({str(on)})"
        # The accumulator is a two-element list: [zero_count, counted_values].
        super().__init__(
            output_name,
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=lambda: [0, 0],
        )

    def aggregate_block(self, block: Block) -> List[int]:
        """Return ``[zero_count, count]`` for the target column of ``block``."""
        accessor = BlockColumnAccessor.for_column(block[self._target_col_name])

        denominator = accessor.count(ignore_nulls=self._ignore_nulls)
        if denominator == 0:
            return [0, 0]

        # Boolean mask marking the zero entries; summing it counts them.
        # A null sum is treated as "no zeros".
        is_zero = pc.equal(accessor._as_arrow_compatible(), 0)
        zeros = pc.sum(is_zero).as_py() or 0

        return [zeros, denominator]

    def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
        """Add up two ``[zero_count, count]`` accumulators component-wise."""
        merged_zeros = current_accumulator[0] + new[0]
        merged_total = current_accumulator[1] + new[1]
        return [merged_zeros, merged_total]

    def finalize(self, accumulator: List[int]) -> Optional[float]:
        """Return the zero percentage, or None when nothing was counted."""
        zeros, denominator = accumulator[0], accumulator[1]
        if denominator == 0:
            return None
        return (zeros / denominator) * 100.0
@PublicAPI(stability="alpha")
class ApproximateQuantile(AggregateFnV2):
    def _require_datasketches(self):
        """Import and return ``kll_floats_sketch``, failing with install instructions."""
        try:
            from datasketches import kll_floats_sketch  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "ApproximateQuantile requires the `datasketches` package. "
                "Install it with `pip install datasketches`."
            ) from exc
        return kll_floats_sketch

    def __init__(
        self,
        on: str,
        quantiles: List[float],
        quantile_precision: int = 800,
        alias_name: Optional[str] = None,
    ):
        """
        Computes the approximate quantiles of a column by using a datasketches kll_floats_sketch.
        https://datasketches.apache.org/docs/KLL/KLLSketch.html

        The accuracy of the KLL quantile sketch is a function of the configured quantile
        precision, which also affects the overall size of the sketch.
        The KLL Sketch has absolute error. For example, a specified rank accuracy of 1% at the
        median (rank = 0.50) means that the true quantile (if you could extract it from the set)
        should be between getQuantile(0.49) and getQuantile(0.51). This same 1% error applied at a
        rank of 0.95 means that the true quantile should be between getQuantile(0.94) and
        getQuantile(0.96). In other words, the error is a fixed +/- epsilon for the entire range
        of ranks.

        Typical single-sided rank error by quantile_precision (use for getQuantile/getRank):
            - quantile_precision=100 → ~2.61%
            - quantile_precision=200 → ~1.33%
            - quantile_precision=400 → ~0.68%
            - quantile_precision=800 → ~0.35%

        See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for details on
        accuracy and size.

        Null values in the target column are ignored when constructing the sketch.

        Example:

            .. testcode::

                import ray
                from ray.data.aggregate import ApproximateQuantile

                # Create a dataset with some values
                ds = ray.data.from_items(
                    [{"value": 20.0}, {"value": 40.0}, {"value": 60.0},
                    {"value": 80.0}, {"value": 100.0}]
                )

                result = ds.aggregate(ApproximateQuantile(on="value", quantiles=[0.1, 0.5, 0.9]))
                # Result: {'approx_quantile(value)': [20.0, 60.0, 100.0]}

        Args:
            on: The name of the column to calculate the quantile on. Must be a numeric column.
            quantiles: The list of quantiles to compute. Must be between 0 and 1 inclusive.
                For example, quantiles=[0.5] computes the median. Null entries in the source
                column are skipped.
            quantile_precision: Controls the accuracy and memory footprint of the sketch
                (K in KLL); higher values yield lower error but use more memory. Defaults
                to 800. See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html
                for details on accuracy and size.
            alias_name: Optional name for the resulting column. If not provided, defaults to
                "approx_quantile({column_name})".
        """
        self._sketch_cls = self._require_datasketches()
        self._quantiles = quantiles
        self._quantile_precision = quantile_precision

        output_name = alias_name or f"approx_quantile({str(on)})"
        super().__init__(
            output_name,
            on=on,
            ignore_nulls=True,
            zero_factory=lambda: self.zero(quantile_precision).serialize(),
        )

    def zero(self, quantile_precision: int):
        """Create an empty sketch with ``k=quantile_precision``."""
        return self._sketch_cls(k=quantile_precision)

    def aggregate_block(self, block: Block) -> bytes:
        """Feed the block's non-null target values into a fresh sketch and serialize it."""
        arrow_table = BlockAccessor.for_block(block).to_arrow()
        values = arrow_table.column(self.get_target_column())

        sketch = self.zero(self._quantile_precision)
        for scalar in values:
            py_value = scalar.as_py()
            if py_value is None:
                # Nulls do not take part in the sketch.
                continue
            sketch.update(float(py_value))

        return sketch.serialize()

    def combine(self, current_accumulator: bytes, new: bytes) -> bytes:
        """Merge two serialized sketches into a new serialized sketch."""
        merged = self.zero(self._quantile_precision)
        for serialized in (current_accumulator, new):
            merged.merge(self._sketch_cls.deserialize(serialized))
        return merged.serialize()

    def finalize(self, accumulator: bytes) -> List[float]:
        """Deserialize the final sketch and query it for the configured quantiles."""
        sketch = self._sketch_cls.deserialize(accumulator)
        return sketch.get_quantiles(self._quantiles)
@PublicAPI(stability="alpha")
class ApproximateTopK(AggregateFnV2):
    def _require_datasketches(self):
        """Import and return ``frequent_strings_sketch``, failing with install instructions."""
        try:
            from datasketches import frequent_strings_sketch
        except ImportError as exc:
            raise ImportError(
                "ApproximateTopK requires the `datasketches` package. "
                "Install it with `pip install datasketches`."
            ) from exc
        return frequent_strings_sketch

    def __init__(
        self,
        on: str,
        k: int,
        log_capacity: int = 15,
        alias_name: Optional[str] = None,
        encode_lists: bool = False,
    ):
        """
        Computes the approximate top k items in a column by using a datasketches frequent_strings_sketch.
        https://datasketches.apache.org/docs/Frequency/FrequentItemsOverview.html

        Guarantees:
            - Any item with true frequency > N / (2^log_capacity) is guaranteed to appear in the results
            - Reported counts may have an error of at most ± N / (2^log_capacity).

        If log_capacity is too small for your data:
            - Low-frequency items may be evicted from the sketch, potentially causing the top-k
              results to miss items that should appear in the output.
            - The error bounds increase, reducing the accuracy of the reported counts.

        Example:

            .. testcode::

                import ray
                from ray.data.aggregate import ApproximateTopK

                ds = ray.data.from_items([
                    {"word": "apple"}, {"word": "banana"}, {"word": "apple"},
                    {"word": "cherry"}, {"word": "apple"}
                ])

                result = ds.aggregate(ApproximateTopK(on="word", k=2))
                # Result: {'approx_topk(word)': [{'word': 'apple', 'count': 3}, {'word': 'banana', 'count': 1}]}

        Args:
            on: The name of the column to aggregate.
            k: The number of top items to return.
            log_capacity: Base 2 logarithm of the maximum size of the internal hash map.
                Higher values increase accuracy but use more memory. Defaults to 15.
            alias_name: The name of the aggregate. Defaults to None.
            encode_lists: If `True`, encode list elements. If `False`, encode
                whole lists (i.e., the entire list is considered as a single object).
                `False` by default. Note that this is a top-level flatten (not a recursive
                flatten) operation.
        """
        self.k = k
        self._log_capacity = log_capacity
        self._frequent_strings_sketch = self._require_datasketches()
        self._encode_lists = encode_lists

        output_name = alias_name or f"approx_topk({str(on)})"
        super().__init__(
            output_name,
            on=on,
            ignore_nulls=True,
            zero_factory=lambda: self.zero(log_capacity).serialize(),
        )

    def zero(self, log_capacity: int):
        """Create an empty sketch with ``lg_max_k=log_capacity``."""
        return self._frequent_strings_sketch(lg_max_k=log_capacity)

    def aggregate_block(self, block: Block) -> bytes:
        """Feed the block's non-null target values into a fresh sketch and serialize it."""
        # Note: The datasketches Python bindings only expose frequent_strings_sketch
        # (not type-specific variants like frequent_ints_sketch). Values are therefore
        # pickled and hex-encoded into strings as a workaround, which is less performant
        # than native type-specific sketches. Revisit if type-specific bindings are added.
        arrow_table = BlockAccessor.for_block(block).to_arrow()
        values = arrow_table.column(self.get_target_column())

        sketch = self.zero(self._log_capacity)
        for scalar in values:
            py_value = scalar.as_py()

            # With list encoding enabled, the members of a list are counted
            # individually; anything else is counted as a single item.
            if self._encode_lists and isinstance(py_value, list):
                candidates = py_value
            else:
                candidates = (py_value,)

            for candidate in candidates:
                if candidate is not None:
                    sketch.update(pickle.dumps(candidate).hex())

        return sketch.serialize()

    def combine(self, current_accumulator: bytes, new: bytes) -> bytes:
        """Merge two serialized sketches into a new serialized sketch."""
        merged = self.zero(self._log_capacity)
        for serialized in (current_accumulator, new):
            merged.merge(self._frequent_strings_sketch.deserialize(serialized))
        return merged.serialize()

    def finalize(self, accumulator: bytes) -> List[Dict[str, Any]]:
        """Decode the first ``k`` frequent items reported by the final sketch.

        Returns:
            A list of dictionaries mapping the target column name to the
            decoded item and "count" to its integer estimate.
        """
        from datasketches import frequent_items_error_type

        column = self.get_target_column()
        sketch = self._frequent_strings_sketch.deserialize(accumulator)
        frequent_items = sketch.get_frequent_items(
            frequent_items_error_type.NO_FALSE_NEGATIVES
        )

        top_items = []
        for item in frequent_items[: self.k]:
            # Each sketch key is the hex-encoded pickle of the original value.
            decoded = pickle.loads(bytes.fromhex(item[0]))
            top_items.append({column: decoded, "count": int(item[1])})
        return top_items
|