# struct_namespace.py
  1. """Struct namespace for expression operations on struct-typed columns."""
  2. from __future__ import annotations
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING
  5. import pyarrow
  6. import pyarrow.compute as pc
  7. from ray.data.datatype import DataType
  8. from ray.data.expressions import pyarrow_udf
  9. if TYPE_CHECKING:
  10. from ray.data.expressions import Expr, UDFExpr
  11. @dataclass
  12. class _StructNamespace:
  13. """Namespace for struct operations on expression columns.
  14. This namespace provides methods for operating on struct-typed columns using
  15. PyArrow compute functions.
  16. Example:
  17. >>> from ray.data.expressions import col
  18. >>> # Access a field using method
  19. >>> expr = col("user_record").struct.field("age")
  20. >>> # Access a field using bracket notation
  21. >>> expr = col("user_record").struct["age"]
  22. >>> # Access nested field
  23. >>> expr = col("user_record").struct["address"].struct["city"]
  24. """
  25. _expr: Expr
  26. def __getitem__(self, field_name: str) -> "UDFExpr":
  27. """Extract a field using bracket notation.
  28. Args:
  29. field_name: The name of the field to extract.
  30. Returns:
  31. UDFExpr that extracts the specified field from each struct.
  32. Example:
  33. >>> col("user").struct["age"] # Get age field # doctest: +SKIP
  34. >>> col("user").struct["address"].struct["city"] # Get nested city field # doctest: +SKIP
  35. """
  36. return self.field(field_name)
  37. def field(self, field_name: str) -> "UDFExpr":
  38. """Extract a field from a struct.
  39. Args:
  40. field_name: The name of the field to extract.
  41. Returns:
  42. UDFExpr that extracts the specified field from each struct.
  43. """
  44. # Infer return type from the struct field type
  45. return_dtype = DataType(object) # fallback
  46. if self._expr.data_type.is_arrow_type():
  47. arrow_type = self._expr.data_type.to_arrow_dtype()
  48. if pyarrow.types.is_struct(arrow_type):
  49. try:
  50. field_type = arrow_type.field(field_name).type
  51. return_dtype = DataType.from_arrow(field_type)
  52. except KeyError:
  53. # Field not found in schema, fallback to object
  54. pass
  55. @pyarrow_udf(return_dtype=return_dtype)
  56. def _struct_field(arr: pyarrow.Array) -> pyarrow.Array:
  57. return pc.struct_field(arr, field_name)
  58. return _struct_field(self._expr)