arr_namespace.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. """Array namespace for expression operations on array-typed columns."""
  2. from __future__ import annotations
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING
  5. import pyarrow
  6. from ray.data.datatype import DataType
  7. from ray.data.expressions import pyarrow_udf
  8. if TYPE_CHECKING:
  9. from ray.data.expressions import Expr, UDFExpr
  10. @dataclass
  11. class _ArrayNamespace:
  12. """Namespace for array operations on expression columns.
  13. Example:
  14. >>> from ray.data.expressions import col
  15. >>> # Convert fixed-size lists to variable-length lists
  16. >>> expr = col("features").arr.to_list()
  17. """
  18. _expr: Expr
  19. def to_list(self) -> "UDFExpr":
  20. """Convert FixedSizeList columns into variable-length lists."""
  21. return_dtype = DataType(object)
  22. expr_dtype = self._expr.data_type
  23. if expr_dtype.is_list_type():
  24. arrow_type = expr_dtype.to_arrow_dtype()
  25. if pyarrow.types.is_fixed_size_list(arrow_type):
  26. return_dtype = DataType.from_arrow(pyarrow.list_(arrow_type.value_type))
  27. else:
  28. return_dtype = expr_dtype
  29. @pyarrow_udf(return_dtype=return_dtype)
  30. def _to_list(arr: pyarrow.Array) -> pyarrow.Array:
  31. arr_dtype = DataType.from_arrow(arr.type)
  32. if not arr_dtype.is_list_type():
  33. raise pyarrow.lib.ArrowInvalid(
  34. "to_list() can only be called on list-like columns, "
  35. f"but got {arr.type}"
  36. )
  37. if isinstance(arr.type, pyarrow.FixedSizeListType):
  38. return arr.cast(pyarrow.list_(arr.type.value_type))
  39. return arr
  40. return _to_list(self._expr)