expressions.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553
  1. from __future__ import annotations
  2. import functools
  3. from abc import ABC, abstractmethod
  4. from dataclasses import dataclass, field
  5. from enum import Enum
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Callable,
  10. Dict,
  11. Generic,
  12. List,
  13. Optional,
  14. Tuple,
  15. Type,
  16. TypeVar,
  17. Union,
  18. )
  19. import pyarrow
  20. import pyarrow.compute as pc
  21. from ray.data.block import BatchColumn
  22. from ray.data.datatype import DataType
  23. from ray.util.annotations import DeveloperAPI, PublicAPI
  24. if TYPE_CHECKING:
  25. from ray.data.namespace_expressions.arr_namespace import _ArrayNamespace
  26. from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
  27. from ray.data.namespace_expressions.list_namespace import _ListNamespace
  28. from ray.data.namespace_expressions.string_namespace import _StringNamespace
  29. from ray.data.namespace_expressions.struct_namespace import _StructNamespace
  # Generic type parameter; appears as ``Type[T]`` in ``Decorated`` below.
  T = TypeVar("T")
  # A callable taking arbitrary arguments and producing a ``UDFExpr``.
  UDFCallable = Callable[..., "UDFExpr"]
  # Union of a UDF-producing callable and a class object of type ``T``.
  Decorated = Union[UDFCallable, Type[T]]
  @DeveloperAPI(stability="alpha")
  class Operation(Enum):
      """Enumeration of supported operations in expressions.

      This enum defines every operator that can appear in an expression node:
      the binary arithmetic, comparison, boolean and membership operations used
      by ``BinaryExpr``, and the unary operations (logical NOT and the null
      checks) used by ``UnaryExpr``.

      Attributes:
          ADD: Addition operation (+)
          SUB: Subtraction operation (-)
          MUL: Multiplication operation (*)
          DIV: Division operation (/)
          MOD: Modulo operation (%)
          FLOORDIV: Floor division operation (//)
          GT: Greater than comparison (>)
          LT: Less than comparison (<)
          GE: Greater than or equal comparison (>=)
          LE: Less than or equal comparison (<=)
          EQ: Equality comparison (==)
          NE: Not equal comparison (!=)
          AND: Logical AND operation (&)
          OR: Logical OR operation (|)
          NOT: Logical NOT operation (~); unary
          IS_NULL: Check if value is null; unary
          IS_NOT_NULL: Check if value is not null; unary
          IN: Check if value is in a list
          NOT_IN: Check if value is not in a list
      """

      ADD = "add"
      SUB = "sub"
      MUL = "mul"
      DIV = "div"
      MOD = "mod"
      FLOORDIV = "floordiv"
      GT = "gt"
      LT = "lt"
      GE = "ge"
      LE = "le"
      EQ = "eq"
      NE = "ne"
      AND = "and"
      OR = "or"
      NOT = "not"
      IS_NULL = "is_null"
      IS_NOT_NULL = "is_not_null"
      IN = "in"
      NOT_IN = "not_in"
  class _ExprVisitor(ABC, Generic[T]):
      """Abstract visitor over Ray Data expression trees.

      ``visit`` routes a node to the ``visit_*`` hook for its concrete
      expression class. Subclasses implement every hook and choose the result
      type ``T``.
      """

      def visit(self, expr: "Expr") -> T:
          """Route ``expr`` to the ``visit_*`` hook matching its node type.

          Args:
              expr: The expression node to visit.

          Returns:
              Whatever the selected hook returns.

          Raises:
              TypeError: If ``expr`` is none of the supported node types.
          """
          # Probed in this exact order with isinstance(), so an instance of a
          # subclass of a node type is handled by that node type's hook.
          routes = (
              (ColumnExpr, "visit_column"),
              (LiteralExpr, "visit_literal"),
              (BinaryExpr, "visit_binary"),
              (UnaryExpr, "visit_unary"),
              (AliasExpr, "visit_alias"),
              (UDFExpr, "visit_udf"),
              (DownloadExpr, "visit_download"),
              (StarExpr, "visit_star"),
          )
          for node_cls, hook_name in routes:
              if isinstance(expr, node_cls):
                  return getattr(self, hook_name)(expr)
          raise TypeError(f"Unsupported expression type for conversion: {type(expr)}")

      @abstractmethod
      def visit_column(self, expr: "ColumnExpr") -> T:
          """Handle a column reference node."""

      @abstractmethod
      def visit_literal(self, expr: "LiteralExpr") -> T:
          """Handle a literal value node."""

      @abstractmethod
      def visit_binary(self, expr: "BinaryExpr") -> T:
          """Handle a two-operand operation node."""

      @abstractmethod
      def visit_unary(self, expr: "UnaryExpr") -> T:
          """Handle a single-operand operation node."""

      @abstractmethod
      def visit_alias(self, expr: "AliasExpr") -> T:
          """Handle an alias (renaming) node."""

      @abstractmethod
      def visit_udf(self, expr: "UDFExpr") -> T:
          """Handle a user-defined-function node."""

      @abstractmethod
      def visit_star(self, expr: "StarExpr") -> T:
          """Handle a star node."""

      @abstractmethod
      def visit_download(self, expr: "DownloadExpr") -> T:
          """Handle a download node."""
  class _PyArrowExpressionVisitor(_ExprVisitor["pyarrow.compute.Expression"]):
      """Visitor that lowers Ray Data expressions to PyArrow compute expressions.

      Columns map to ``pc.field``, literals map to ``pc.scalar`` and aliases are
      unwrapped. Binary and unary operators are translated through
      ``_ARROW_EXPR_OPS_MAP``. ``is_in``/``not_in`` are special-cased because
      their right operand is materialized as a ``pyarrow`` array value set.
      UDF, download and star expressions have no PyArrow form and raise
      ``TypeError``.
      """

      def visit_column(self, expr: "ColumnExpr") -> "pyarrow.compute.Expression":
          """Map a column reference to ``pc.field``."""
          return pc.field(expr.name)

      def visit_literal(self, expr: "LiteralExpr") -> "pyarrow.compute.Expression":
          """Map a literal to ``pc.scalar``."""
          return pc.scalar(expr.value)

      def visit_binary(self, expr: "BinaryExpr") -> "pyarrow.compute.Expression":
          """Translate a binary node into a PyArrow expression.

          Raises:
              ValueError: If an ``is_in``/``not_in`` right operand is not a
                  literal, or if the operator has no PyArrow mapping.
          """
          import pyarrow as pa

          if expr.op not in (Operation.IN, Operation.NOT_IN):
              lhs = self.visit(expr.left)
              rhs = self.visit(expr.right)
              # Imported lazily (presumably to avoid an import cycle).
              from ray.data._internal.planner.plan_expression.expression_evaluator import (
                  _ARROW_EXPR_OPS_MAP,
              )

              if expr.op not in _ARROW_EXPR_OPS_MAP:
                  raise ValueError(f"Unsupported binary operation for PyArrow: {expr.op}")
              return _ARROW_EXPR_OPS_MAP[expr.op](lhs, rhs)

          lhs = self.visit(expr.left)
          if not isinstance(expr.right, LiteralExpr):
              raise ValueError(
                  f"is_in/not_in operations require the right operand to be a "
                  f"literal list, got {type(expr.right).__name__}."
              )
          raw = expr.right.value
          # A non-list literal is treated as a one-element value set.
          value_set = pa.array(raw if isinstance(raw, list) else [raw])
          membership = pc.is_in(lhs, value_set)
          if expr.op == Operation.NOT_IN:
              return pc.invert(membership)
          return membership

      def visit_unary(self, expr: "UnaryExpr") -> "pyarrow.compute.Expression":
          """Translate a unary node through ``_ARROW_EXPR_OPS_MAP``.

          Raises:
              ValueError: If the operator has no PyArrow mapping.
          """
          arrow_operand = self.visit(expr.operand)
          from ray.data._internal.planner.plan_expression.expression_evaluator import (
              _ARROW_EXPR_OPS_MAP,
          )

          if expr.op not in _ARROW_EXPR_OPS_MAP:
              raise ValueError(f"Unsupported unary operation for PyArrow: {expr.op}")
          return _ARROW_EXPR_OPS_MAP[expr.op](arrow_operand)

      def visit_alias(self, expr: "AliasExpr") -> "pyarrow.compute.Expression":
          """Aliases carry no computation; lower the wrapped expression."""
          return self.visit(expr.expr)

      def visit_udf(self, expr: "UDFExpr") -> "pyarrow.compute.Expression":
          """UDFs have no PyArrow equivalent."""
          raise TypeError("UDF expressions cannot be converted to PyArrow expressions")

      def visit_download(self, expr: "DownloadExpr") -> "pyarrow.compute.Expression":
          """Download expressions have no PyArrow equivalent."""
          raise TypeError(
              "Download expressions cannot be converted to PyArrow expressions"
          )

      def visit_star(self, expr: "StarExpr") -> "pyarrow.compute.Expression":
          """Star expressions have no PyArrow equivalent."""
          raise TypeError("Star expressions cannot be converted to PyArrow expressions")
  @DeveloperAPI(stability="alpha")
  @dataclass(frozen=True)
  class Expr(ABC):
      """Base class for all expression nodes.

      This is the abstract base class that all expression types inherit from.
      It provides operator overloads for building complex expressions using
      standard Python operators.

      Expressions form a tree structure where each node represents an operation
      or value. The tree can be evaluated against data batches to compute results.

      Example:
          >>> from ray.data.expressions import col, lit
          >>> # Create an expression tree: (col("x") + 5) * col("y")
          >>> expr = (col("x") + lit(5)) * col("y")
          >>> # This creates a BinaryExpr with op=MUL
          >>> # left=BinaryExpr(op=ADD, left=ColumnExpr("x"), right=LiteralExpr(5))
          >>> # right=ColumnExpr("y")

      Note:
          This class should not be instantiated directly. Use the concrete
          subclasses like ColumnExpr, LiteralExpr, etc.
      """

      data_type: DataType

      @property
      def name(self) -> str | None:
          """Get the name associated with this expression.

          Returns:
              The name for expressions that have one (ColumnExpr, AliasExpr),
              None otherwise.
          """
          return None

      @abstractmethod
      def structurally_equals(self, other: Any) -> bool:
          """Compare two expression ASTs for structural equality.

          ``==`` cannot serve this purpose on expressions because it is
          overloaded to build an ``EQ`` expression.
          """
          raise NotImplementedError

      def to_pyarrow(self) -> "pyarrow.compute.Expression":
          """Convert this Ray Data expression to a PyArrow compute expression.

          Returns:
              A PyArrow compute expression equivalent to this Ray Data expression.

          Raises:
              ValueError: If the expression contains operations not supported by PyArrow.
              TypeError: If the expression type cannot be converted to PyArrow.
          """
          return _PyArrowExpressionVisitor().visit(self)

      def __repr__(self) -> str:
          """Return a tree-structured string representation of the expression.

          Returns:
              A multi-line string showing the expression tree structure using
              box-drawing characters for visual clarity.

          Example:
              >>> from ray.data.expressions import col, lit
              >>> expr = (col("x") + lit(5)) * col("y")
              >>> print(expr)
              MUL
              ├── left: ADD
              │ ├── left: COL('x')
              │ └── right: LIT(5)
              └── right: COL('y')
          """
          from ray.data._internal.planner.plan_expression.expression_visitors import (
              _TreeReprVisitor,
          )

          return _TreeReprVisitor().visit(self)

      def _bin(self, other: Any, op: Operation) -> "Expr":
          """Create a binary expression with the given operation.

          Args:
              other: The right operand expression or literal value
              op: The operation to perform

          Returns:
              A new BinaryExpr representing the operation

          Note:
              If other is not an Expr, it will be automatically converted to a LiteralExpr.
          """
          if not isinstance(other, Expr):
              other = LiteralExpr(other)
          return BinaryExpr(op, self, other)

      #
      # Arithmetic ops
      #
      def __add__(self, other: Any) -> "Expr":
          """Addition operator (+)."""
          return self._bin(other, Operation.ADD)

      def __radd__(self, other: Any) -> "Expr":
          """Reverse addition operator (for literal + expr)."""
          return LiteralExpr(other)._bin(self, Operation.ADD)

      def __sub__(self, other: Any) -> "Expr":
          """Subtraction operator (-)."""
          return self._bin(other, Operation.SUB)

      def __rsub__(self, other: Any) -> "Expr":
          """Reverse subtraction operator (for literal - expr)."""
          return LiteralExpr(other)._bin(self, Operation.SUB)

      def __mul__(self, other: Any) -> "Expr":
          """Multiplication operator (*)."""
          return self._bin(other, Operation.MUL)

      def __rmul__(self, other: Any) -> "Expr":
          """Reverse multiplication operator (for literal * expr)."""
          return LiteralExpr(other)._bin(self, Operation.MUL)

      def __mod__(self, other: Any) -> "Expr":
          """Modulo operator (%)."""
          return self._bin(other, Operation.MOD)

      def __rmod__(self, other: Any) -> "Expr":
          """Reverse modulo operator (for literal % expr)."""
          return LiteralExpr(other)._bin(self, Operation.MOD)

      def __truediv__(self, other: Any) -> "Expr":
          """Division operator (/)."""
          return self._bin(other, Operation.DIV)

      def __rtruediv__(self, other: Any) -> "Expr":
          """Reverse division operator (for literal / expr)."""
          return LiteralExpr(other)._bin(self, Operation.DIV)

      def __floordiv__(self, other: Any) -> "Expr":
          """Floor division operator (//)."""
          return self._bin(other, Operation.FLOORDIV)

      def __rfloordiv__(self, other: Any) -> "Expr":
          """Reverse floor division operator (for literal // expr)."""
          return LiteralExpr(other)._bin(self, Operation.FLOORDIV)

      # comparison
      def __gt__(self, other: Any) -> "Expr":
          """Greater than operator (>)."""
          return self._bin(other, Operation.GT)

      def __lt__(self, other: Any) -> "Expr":
          """Less than operator (<)."""
          return self._bin(other, Operation.LT)

      def __ge__(self, other: Any) -> "Expr":
          """Greater than or equal operator (>=)."""
          return self._bin(other, Operation.GE)

      def __le__(self, other: Any) -> "Expr":
          """Less than or equal operator (<=)."""
          return self._bin(other, Operation.LE)

      def __eq__(self, other: Any) -> "Expr":
          """Equality operator (==)."""
          return self._bin(other, Operation.EQ)

      def __ne__(self, other: Any) -> "Expr":
          """Not equal operator (!=)."""
          return self._bin(other, Operation.NE)

      # boolean
      def __and__(self, other: Any) -> "Expr":
          """Logical AND operator (&)."""
          return self._bin(other, Operation.AND)

      def __rand__(self, other: Any) -> "Expr":
          """Reverse logical AND operator (for literal & expr)."""
          # Without this, ``True & col("x")`` raises TypeError because the
          # left operand's ``__and__`` does not understand expressions.
          return LiteralExpr(other)._bin(self, Operation.AND)

      def __or__(self, other: Any) -> "Expr":
          """Logical OR operator (|)."""
          return self._bin(other, Operation.OR)

      def __ror__(self, other: Any) -> "Expr":
          """Reverse logical OR operator (for literal | expr)."""
          return LiteralExpr(other)._bin(self, Operation.OR)

      def __invert__(self) -> "Expr":
          """Logical NOT operator (~)."""
          return UnaryExpr(Operation.NOT, self)

      # predicate methods
      def is_null(self) -> "Expr":
          """Check if the expression value is null."""
          return UnaryExpr(Operation.IS_NULL, self)

      def is_not_null(self) -> "Expr":
          """Check if the expression value is not null."""
          return UnaryExpr(Operation.IS_NOT_NULL, self)

      def is_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
          """Check if the expression value is in a list of values.

          Args:
              values: A list of candidate values (wrapped in a ``LiteralExpr``)
                  or an expression.
          """
          # ``_bin`` wraps non-Expr operands in a LiteralExpr.
          return self._bin(values, Operation.IN)

      def not_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
          """Check if the expression value is not in a list of values.

          Args:
              values: A list of candidate values (wrapped in a ``LiteralExpr``)
                  or an expression.
          """
          # ``_bin`` wraps non-Expr operands in a LiteralExpr.
          return self._bin(values, Operation.NOT_IN)

      def alias(self, name: str) -> "Expr":
          """Rename the expression.

          This method allows you to assign a new name to an expression result.
          This is particularly useful when you want to specify the output column name
          directly within the expression rather than as a separate parameter.

          Args:
              name: The new name for the expression

          Returns:
              An AliasExpr that wraps this expression with the specified name

          Example:
              >>> from ray.data.expressions import col, lit
              >>> # Create an expression with a new aliased name
              >>> expr = (col("price") * col("quantity")).alias("total")
              >>> # Can be used with Dataset operations that support named expressions
          """
          return AliasExpr(
              data_type=self.data_type, expr=self, _name=name, _is_rename=False
          )

      # rounding helpers
      def ceil(self) -> "UDFExpr":
          """Round values up to the nearest integer."""
          return _create_pyarrow_compute_udf(pc.ceil)(self)

      def floor(self) -> "UDFExpr":
          """Round values down to the nearest integer."""
          return _create_pyarrow_compute_udf(pc.floor)(self)

      def round(self) -> "UDFExpr":
          """Round values to the nearest integer using PyArrow semantics."""
          return _create_pyarrow_compute_udf(pc.round)(self)

      def trunc(self) -> "UDFExpr":
          """Truncate fractional values toward zero."""
          return _create_pyarrow_compute_udf(pc.trunc)(self)

      # logarithmic helpers
      def ln(self) -> "UDFExpr":
          """Compute the natural logarithm of the expression."""
          return _create_pyarrow_compute_udf(pc.ln, return_dtype=DataType.float64())(self)

      def log10(self) -> "UDFExpr":
          """Compute the base-10 logarithm of the expression."""
          return _create_pyarrow_compute_udf(pc.log10, return_dtype=DataType.float64())(
              self
          )

      def log2(self) -> "UDFExpr":
          """Compute the base-2 logarithm of the expression."""
          return _create_pyarrow_compute_udf(pc.log2, return_dtype=DataType.float64())(
              self
          )

      def exp(self) -> "UDFExpr":
          """Compute the natural exponential of the expression."""
          return _create_pyarrow_compute_udf(pc.exp, return_dtype=DataType.float64())(
              self
          )

      # trigonometric helpers
      def sin(self) -> "UDFExpr":
          """Compute the sine of the expression (in radians)."""
          return _create_pyarrow_compute_udf(pc.sin, return_dtype=DataType.float64())(
              self
          )

      def cos(self) -> "UDFExpr":
          """Compute the cosine of the expression (in radians)."""
          return _create_pyarrow_compute_udf(pc.cos, return_dtype=DataType.float64())(
              self
          )

      def tan(self) -> "UDFExpr":
          """Compute the tangent of the expression (in radians)."""
          return _create_pyarrow_compute_udf(pc.tan, return_dtype=DataType.float64())(
              self
          )

      def asin(self) -> "UDFExpr":
          """Compute the arcsine (inverse sine) of the expression, returning radians."""
          return _create_pyarrow_compute_udf(pc.asin, return_dtype=DataType.float64())(
              self
          )

      def acos(self) -> "UDFExpr":
          """Compute the arccosine (inverse cosine) of the expression, returning radians."""
          return _create_pyarrow_compute_udf(pc.acos, return_dtype=DataType.float64())(
              self
          )

      def atan(self) -> "UDFExpr":
          """Compute the arctangent (inverse tangent) of the expression, returning radians."""
          return _create_pyarrow_compute_udf(pc.atan, return_dtype=DataType.float64())(
              self
          )

      # arithmetic helpers
      def negate(self) -> "UDFExpr":
          """Compute the negation of the expression.

          Returns:
              A UDFExpr that computes the negation (multiplies values by -1).

          Example:
              >>> from ray.data.expressions import col
              >>> import ray
              >>> ds = ray.data.from_items([{"x": 5}, {"x": -3}])
              >>> ds = ds.with_column("neg_x", col("x").negate())
              >>> # Result: neg_x = [-5, 3]
          """
          return _create_pyarrow_compute_udf(pc.negate_checked)(self)

      def sign(self) -> "UDFExpr":
          """Compute the sign of the expression.

          Returns:
              A UDFExpr that returns -1 for negative values, 0 for zero, and 1 for positive values.

          Example:
              >>> from ray.data.expressions import col
              >>> import ray
              >>> ds = ray.data.from_items([{"x": 5}, {"x": -3}, {"x": 0}])
              >>> ds = ds.with_column("sign_x", col("x").sign())
              >>> # Result: sign_x = [1, -1, 0]
          """
          return _create_pyarrow_compute_udf(pc.sign)(self)

      def power(self, exponent: Any) -> "UDFExpr":
          """Raise the expression to the given power.

          Args:
              exponent: The exponent to raise the expression to.

          Returns:
              A UDFExpr that computes the power operation.

          Example:
              >>> from ray.data.expressions import col, lit
              >>> import ray
              >>> ds = ray.data.from_items([{"x": 2}, {"x": 3}])
              >>> ds = ds.with_column("x_squared", col("x").power(2))
              >>> # Result: x_squared = [4, 9]
              >>> ds = ds.with_column("x_cubed", col("x").power(3))
              >>> # Result: x_cubed = [8, 27]
          """
          return _create_pyarrow_compute_udf(pc.power)(self, exponent)

      def abs(self) -> "UDFExpr":
          """Compute the absolute value of the expression.

          Returns:
              A UDFExpr that computes the absolute value.

          Example:
              >>> from ray.data.expressions import col
              >>> import ray
              >>> ds = ray.data.from_items([{"x": 5}, {"x": -3}])
              >>> ds = ds.with_column("abs_x", col("x").abs())
              >>> # Result: abs_x = [5, 3]
          """
          return _create_pyarrow_compute_udf(pc.abs_checked)(self)

      @property
      def arr(self) -> "_ArrayNamespace":
          """Access array operations for this expression."""
          from ray.data.namespace_expressions.arr_namespace import _ArrayNamespace

          return _ArrayNamespace(self)

      @property
      def list(self) -> "_ListNamespace":
          """Access list operations for this expression.

          Returns:
              A _ListNamespace that provides list-specific operations for both
              PyArrow ``List`` and ``FixedSizeList`` columns.

          Example:
              >>> from ray.data.expressions import col
              >>> import ray
              >>> ds = ray.data.from_items([
              ...     {"items": [1, 2, 3]},
              ...     {"items": [4, 5]}
              ... ])
              >>> ds = ds.with_column("num_items", col("items").list.len())
              >>> ds = ds.with_column("first_item", col("items").list[0])
              >>> ds = ds.with_column("slice", col("items").list[1:3])
          """
          from ray.data.namespace_expressions.list_namespace import _ListNamespace

          return _ListNamespace(self)

      @property
      def str(self) -> "_StringNamespace":
          """Access string operations for this expression.

          Returns:
              A _StringNamespace that provides string-specific operations.

          Example:
              >>> from ray.data.expressions import col
              >>> import ray
              >>> ds = ray.data.from_items([
              ...     {"name": "Alice"},
              ...     {"name": "Bob"}
              ... ])
              >>> ds = ds.with_column("upper_name", col("name").str.upper())
              >>> ds = ds.with_column("name_len", col("name").str.len())
              >>> ds = ds.with_column("starts_a", col("name").str.starts_with("A"))
          """
          from ray.data.namespace_expressions.string_namespace import _StringNamespace

          return _StringNamespace(self)

      @property
      def struct(self) -> "_StructNamespace":
          """Access struct operations for this expression.

          Returns:
              A _StructNamespace that provides struct-specific operations.

          Example:
              >>> from ray.data.expressions import col
              >>> import ray
              >>> import pyarrow as pa
              >>> ds = ray.data.from_arrow(pa.table({
              ...     "user": pa.array([
              ...         {"name": "Alice", "age": 30}
              ...     ], type=pa.struct([
              ...         pa.field("name", pa.string()),
              ...         pa.field("age", pa.int32())
              ...     ]))
              ... }))
              >>> ds = ds.with_column("age", col("user").struct["age"]) # doctest: +SKIP
          """
          from ray.data.namespace_expressions.struct_namespace import _StructNamespace

          return _StructNamespace(self)

      @property
      def dt(self) -> "_DatetimeNamespace":
          """Access datetime operations for this expression."""
          from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace

          return _DatetimeNamespace(self)

      def _unalias(self) -> "Expr":
          """Return the expression without any alias wrapper (identity on the base class)."""
          return self
  @DeveloperAPI(stability="alpha")
  @dataclass(frozen=True, eq=False, repr=False)
  class ColumnExpr(Expr):
      """Reference to an existing dataset column, identified by name.

      Evaluating this expression yields the values stored in the named
      column of the dataset.

      Args:
          _name: The name of the column to reference (exposed through the
              ``name`` property).

      Example:
          >>> from ray.data.expressions import col
          >>> # Reference the "age" column
          >>> age_expr = col("age") # Creates ColumnExpr(name="age")
      """

      _name: str
      # Defaults to a generic ``object`` dtype and is not an __init__ parameter.
      data_type: DataType = field(default_factory=lambda: DataType(object), init=False)

      @property
      def name(self) -> str:
          """The referenced column's name."""
          return self._name

      def _rename(self, name: str) -> "AliasExpr":
          """Wrap this column in an ``AliasExpr`` flagged as a rename."""
          return AliasExpr(
              data_type=self.data_type, expr=self, _name=name, _is_rename=True
          )

      def structurally_equals(self, other: Any) -> bool:
          """Columns are structurally equal when they reference the same name."""
          if not isinstance(other, ColumnExpr):
              return False
          return self.name == other.name
  @DeveloperAPI(stability="alpha")
  @dataclass(frozen=True, eq=False, repr=False)
  class LiteralExpr(Expr):
      """Expression that represents a constant scalar value.

      This expression type represents a literal value that will be broadcast
      to all rows when evaluated. The value can be any Python object.

      Args:
          value: The constant value to represent

      Example:
          >>> from ray.data.expressions import lit
          >>> import numpy as np
          >>> # Create a literal value
          >>> five = lit(5) # Creates LiteralExpr(value=5)
          >>> name = lit("John") # Creates LiteralExpr(value="John")
          >>> numpy_val = lit(np.int32(42)) # Creates LiteralExpr with numpy type
      """

      value: Any
      # Not an __init__ parameter; filled in by __post_init__ from ``value``.
      data_type: DataType = field(init=False)

      def __post_init__(self):
          # Infer the type from the value using DataType.infer_dtype
          inferred_dtype = DataType.infer_dtype(self.value)
          # Use object.__setattr__ since the dataclass is frozen
          object.__setattr__(self, "data_type", inferred_dtype)

      def structurally_equals(self, other: Any) -> bool:
          """Return whether ``other`` is a literal with an equal value of the same type.

          The value types are compared first, so ``==`` is only ever invoked
          between values of identical type. The identity check keeps the
          comparison reflexive for values that are unequal to themselves under
          ``==`` (for example a NaN float), so ``e.structurally_equals(e)``
          always holds.
          """
          if not isinstance(other, LiteralExpr):
              return False
          if type(self.value) is not type(other.value):
              return False
          if self.value is other.value:
              return True
          return bool(self.value == other.value)
  @DeveloperAPI(stability="alpha")
  @dataclass(frozen=True, eq=False, repr=False)
  class BinaryExpr(Expr):
      """Node applying a two-operand ``Operation`` to ``left`` and ``right``.

      Nodes of this type are usually produced by the operator overloads on
      ``Expr`` (``+``, ``>``, ``&``, ...) rather than constructed by hand.

      Args:
          op: The operation to perform (from Operation enum)
          left: The left operand expression
          right: The right operand expression

      Example:
          >>> from ray.data.expressions import BinaryExpr, Operation, col, lit
          >>> # Manually create a binary expression (usually done via operators)
          >>> expr = BinaryExpr(Operation.ADD, col("x"), lit(5))
          >>> # This is equivalent to: col("x") + lit(5)
      """

      op: Operation
      left: Expr
      right: Expr
      # Defaults to a generic ``object`` dtype and is not an __init__ parameter.
      data_type: DataType = field(default_factory=lambda: DataType(object), init=False)

      def structurally_equals(self, other: Any) -> bool:
          """Equal when the operator matches and both operand subtrees match."""
          if not isinstance(other, BinaryExpr) or self.op is not other.op:
              return False
          return self.left.structurally_equals(
              other.left
          ) and self.right.structurally_equals(other.right)
  @DeveloperAPI(stability="alpha")
  @dataclass(frozen=True, eq=False, repr=False)
  class UnaryExpr(Expr):
      """Node applying a single-operand ``Operation`` to ``operand``.

      ``Expr.__invert__`` (``~``), ``Expr.is_null`` and ``Expr.is_not_null``
      build nodes of this type.

      Args:
          op: The operation to perform (from Operation enum)
          operand: The operand expression

      Example:
          >>> from ray.data.expressions import col
          >>> # Check if a column is null
          >>> expr = col("age").is_null() # Creates UnaryExpr(IS_NULL, col("age"))
          >>> # Logical not
          >>> expr = ~(col("active")) # Creates UnaryExpr(NOT, col("active"))
      """

      op: Operation
      operand: Expr
      # The unary operations built by Expr (NOT, IS_NULL, IS_NOT_NULL) are
      # boolean-valued, so the dtype defaults to bool. Downstream wrappers such
      # as AliasExpr (e.g. col("x").is_not_null().alias("valid")) read it.
      data_type: DataType = field(default_factory=lambda: DataType.bool(), init=False)

      def structurally_equals(self, other: Any) -> bool:
          """Equal when the operator matches and the operand subtrees match."""
          if not isinstance(other, UnaryExpr):
              return False
          if self.op is not other.op:
              return False
          return self.operand.structurally_equals(other.operand)
  644. @dataclass(frozen=True)
  645. class _CallableClassSpec:
  646. """Specification for a callable class UDF.
  647. This dataclass captures the class type and constructor arguments needed
  648. to instantiate a callable class UDF on an actor. It consolidates the
  649. callable class metadata that was previously spread across multiple fields.
  650. Attributes:
  651. cls: The original callable class type
  652. args: Positional arguments for the constructor
  653. kwargs: Keyword arguments for the constructor
  654. _cached_key: Pre-computed key that survives serialization
  655. """
  656. cls: type
  657. args: Tuple[Any, ...] = ()
  658. kwargs: Dict[str, Any] = field(default_factory=dict)
  659. _cached_key: Optional[Tuple] = field(default=None, compare=False, repr=False)
  660. def __post_init__(self):
  661. """Pre-compute and cache the key at construction time.
  662. This ensures the same key survives serialization, since the cached
  663. key tuple (containing the already-computed repr strings) gets pickled
  664. and unpickled as-is.
  665. """
  666. if self._cached_key is None:
  667. class_id = f"{self.cls.__module__}.{self.cls.__qualname__}"
  668. try:
  669. key = (
  670. class_id,
  671. self.args,
  672. tuple(sorted(self.kwargs.items())),
  673. )
  674. # Verify the key is actually hashable (args may contain lists)
  675. hash(key)
  676. except TypeError:
  677. # Fallback for unhashable args/kwargs - use repr for comparison
  678. key = (class_id, repr(self.args), repr(self.kwargs))
  679. # Use object.__setattr__ since dataclass is frozen
  680. object.__setattr__(self, "_cached_key", key)
  681. def make_key(self) -> Tuple:
  682. """Return the pre-computed hashable key for UDF instance lookup.
  683. The key uniquely identifies a UDF by its class and constructor arguments.
  684. This ensures that the same class with different constructor args
  685. (e.g., Multiplier(2) vs Multiplier(3)) are treated as distinct UDFs.
  686. Returns:
  687. A hashable tuple that uniquely identifies this UDF configuration.
  688. """
  689. return self._cached_key
  690. class _CallableClassUDF:
  691. """A wrapper that makes callable class UDFs appear as regular functions.
  692. This class wraps callable class UDFs for use in expressions. It provides
  693. an `init()` method that should be called at actor startup via `init_fn`
  694. to instantiate the underlying class before any blocks are processed.
  695. Key responsibilities:
  696. 1. Store the callable class and constructor arguments
  697. 2. Provide init() for actor startup initialization
  698. 3. Handle async bridging for coroutine/async generator UDFs
  699. 4. Reuse the same instance across all calls (actor semantics)
  700. Example:
  701. >>> @udf(return_dtype=DataType.int32())
  702. ... class AddOffset:
  703. ... def __init__(self, offset=1):
  704. ... self.offset = offset
  705. ... def __call__(self, x):
  706. ... return pc.add(x, self.offset)
  707. >>>
  708. >>> add_five = AddOffset(5) # Creates _CallableClassUDF internally
  709. >>> expr = add_five(col("value")) # Creates UDFExpr with fn=_CallableClassUDF
  710. """
  711. def __init__(
  712. self,
  713. cls: type,
  714. ctor_args: Tuple[Any, ...],
  715. ctor_kwargs: Dict[str, Any],
  716. return_dtype: DataType,
  717. ):
  718. """Initialize the _CallableClassUDF wrapper.
  719. Args:
  720. cls: The original callable class
  721. ctor_args: Constructor positional arguments
  722. ctor_kwargs: Constructor keyword arguments
  723. return_dtype: The return data type for schema inference
  724. """
  725. self._cls = cls
  726. self._ctor_args = ctor_args
  727. self._ctor_kwargs = ctor_kwargs
  728. self._return_dtype = return_dtype
  729. # Instance created by init() at actor startup
  730. self._instance = None
  731. # Cache the spec to avoid creating new instances on each access
  732. self._callable_class_spec = _CallableClassSpec(
  733. cls=cls,
  734. args=ctor_args,
  735. kwargs=ctor_kwargs,
  736. )
  737. @property
  738. def __name__(self) -> str:
  739. """Return the original class name for error messages."""
  740. return self._cls.__name__
  741. @property
  742. def callable_class_spec(self) -> _CallableClassSpec:
  743. """Return the callable class spec for this UDF.
  744. Used for deduplication when the same UDF appears multiple times
  745. in an expression tree.
  746. """
  747. return self._callable_class_spec
  748. def init(self) -> None:
  749. """Initialize the UDF instance. Called at actor startup via init_fn.
  750. This ensures the callable class is instantiated before any blocks
  751. are processed, matching the behavior of map_batches callable classes.
  752. """
  753. if self._instance is None:
  754. self._instance = self._cls(*self._ctor_args, **self._ctor_kwargs)
  755. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  756. """Call the UDF instance.
  757. Args:
  758. *args: Evaluated expression arguments (PyArrow arrays, etc.)
  759. **kwargs: Evaluated expression keyword arguments
  760. Returns:
  761. The result of calling the UDF instance
  762. Raises:
  763. RuntimeError: If init() was not called before __call__
  764. """
  765. if self._instance is None:
  766. raise RuntimeError(
  767. f"_CallableClassUDF '{self._cls.__name__}' was not initialized. "
  768. f"init() must be called before __call__. This typically happens "
  769. f"via init_fn at actor startup."
  770. )
  771. from ray.data.util.expression_utils import _call_udf_instance_with_async_bridge
  772. # Call instance directly, handling async if needed
  773. return _call_udf_instance_with_async_bridge(self._instance, *args, **kwargs)
  774. @DeveloperAPI(stability="alpha")
  775. @dataclass(frozen=True, eq=False, repr=False)
  776. class UDFExpr(Expr):
  777. """Expression that represents a user-defined function call.
  778. This expression type wraps a UDF with schema inference capabilities,
  779. allowing UDFs to be used seamlessly within the expression system.
  780. UDFs operate on batches of data, where each column argument is passed
  781. as a PyArrow Array containing multiple values from that column across the batch.
  782. Args:
  783. fn: The user-defined function to call. For callable classes, this is an
  784. _CallableClassUDF instance that handles lazy instantiation internally.
  785. args: List of argument expressions (positional arguments)
  786. kwargs: Dictionary of keyword argument expressions
  787. Example:
  788. >>> from ray.data.expressions import col, udf
  789. >>> import pyarrow as pa
  790. >>> import pyarrow.compute as pc
  791. >>> from ray.data.datatype import DataType
  792. >>>
  793. >>> @udf(return_dtype=DataType.int32())
  794. ... def add_one(x: pa.Array) -> pa.Array:
  795. ... return pc.add(x, 1)
  796. >>>
  797. >>> # Use in expressions
  798. >>> expr = add_one(col("value"))
  799. >>> # Callable class example
  800. >>> @udf(return_dtype=DataType.int32())
  801. ... class AddOffset:
  802. ... def __init__(self, offset=1):
  803. ... self.offset = offset
  804. ... def __call__(self, x: pa.Array) -> pa.Array:
  805. ... return pc.add(x, self.offset)
  806. >>>
  807. >>> # Use callable class
  808. >>> add_five = AddOffset(5)
  809. >>> expr = add_five(col("value"))
  810. """
  811. fn: Callable[..., BatchColumn] # Can be regular function OR _CallableClassUDF
  812. args: List[Expr]
  813. kwargs: Dict[str, Expr]
  814. @property
  815. def callable_class_spec(self) -> Optional[_CallableClassSpec]:
  816. """Return callable_class_spec if fn is an _CallableClassUDF, else None.
  817. This property maintains backward compatibility with code that checks
  818. for callable_class_spec.
  819. """
  820. if isinstance(self.fn, _CallableClassUDF):
  821. return self.fn.callable_class_spec
  822. return None
  823. def structurally_equals(self, other: Any) -> bool:
  824. if not isinstance(other, UDFExpr):
  825. return False
  826. # For callable class UDFs (_CallableClassUDF), compare the callable_class_spec.
  827. # For regular function UDFs, compare fn directly.
  828. if isinstance(self.fn, _CallableClassUDF):
  829. if not isinstance(other.fn, _CallableClassUDF):
  830. return False
  831. if self.fn.callable_class_spec != other.fn.callable_class_spec:
  832. return False
  833. else:
  834. if self.fn != other.fn:
  835. return False
  836. return (
  837. len(self.args) == len(other.args)
  838. and all(a.structurally_equals(b) for a, b in zip(self.args, other.args))
  839. and self.kwargs.keys() == other.kwargs.keys()
  840. and all(
  841. self.kwargs[k].structurally_equals(other.kwargs[k])
  842. for k in self.kwargs.keys()
  843. )
  844. )
  845. def _create_udf_callable(
  846. fn: Callable[..., BatchColumn],
  847. return_dtype: DataType,
  848. ) -> Callable[..., UDFExpr]:
  849. """Create a callable that generates UDFExpr when called with expressions.
  850. Args:
  851. fn: The user-defined function to wrap. Can be a regular function
  852. or an _CallableClassUDF instance (for callable classes).
  853. return_dtype: The return data type of the UDF
  854. Returns:
  855. A callable that creates UDFExpr instances when called with expressions
  856. """
  857. def udf_callable(*args, **kwargs) -> UDFExpr:
  858. # Convert arguments to expressions if they aren't already
  859. expr_args = []
  860. for arg in args:
  861. if isinstance(arg, Expr):
  862. expr_args.append(arg)
  863. else:
  864. expr_args.append(LiteralExpr(arg))
  865. expr_kwargs = {}
  866. for k, v in kwargs.items():
  867. if isinstance(v, Expr):
  868. expr_kwargs[k] = v
  869. else:
  870. expr_kwargs[k] = LiteralExpr(v)
  871. return UDFExpr(
  872. fn=fn,
  873. args=expr_args,
  874. kwargs=expr_kwargs,
  875. data_type=return_dtype,
  876. )
  877. # Preserve original function metadata
  878. functools.update_wrapper(udf_callable, fn)
  879. # Store the original function for access if needed
  880. udf_callable._original_fn = fn
  881. return udf_callable
  882. @PublicAPI(stability="alpha")
  883. def udf(return_dtype: DataType) -> Callable[..., UDFExpr]:
  884. """
  885. Decorator to convert a UDF into an expression-compatible function.
  886. This decorator allows UDFs to be used seamlessly within the expression system,
  887. enabling schema inference and integration with other expressions.
  888. IMPORTANT: UDFs operate on batches of data, not individual rows. When your UDF
  889. is called, each column argument will be passed as a PyArrow Array containing
  890. multiple values from that column across the batch. Under the hood, when working
  891. with multiple columns, they get translated to PyArrow arrays (one array per column).
  892. Args:
  893. return_dtype: The data type of the return value of the UDF
  894. Returns:
  895. A callable that creates UDFExpr instances when called with expressions
  896. Example:
  897. >>> from ray.data.expressions import col, udf
  898. >>> import pyarrow as pa
  899. >>> import pyarrow.compute as pc
  900. >>> import ray
  901. >>>
  902. >>> # UDF that operates on a batch of values (PyArrow Array)
  903. >>> @udf(return_dtype=DataType.int32())
  904. ... def add_one(x: pa.Array) -> pa.Array:
  905. ... return pc.add(x, 1) # Vectorized operation on the entire Array
  906. >>>
  907. >>> # UDF that combines multiple columns (each as a PyArrow Array)
  908. >>> @udf(return_dtype=DataType.string())
  909. ... def format_name(first: pa.Array, last: pa.Array) -> pa.Array:
  910. ... return pc.binary_join_element_wise(first, last, " ") # Vectorized string concatenation
  911. >>>
  912. >>> # Callable class UDF
  913. >>> @udf(return_dtype=DataType.int32())
  914. ... class AddOffset:
  915. ... def __init__(self, offset=1):
  916. ... self.offset = offset
  917. ... def __call__(self, x: pa.Array) -> pa.Array:
  918. ... return pc.add(x, self.offset)
  919. >>>
  920. >>> # Use in dataset operations
  921. >>> ds = ray.data.from_items([
  922. ... {"value": 5, "first": "John", "last": "Doe"},
  923. ... {"value": 10, "first": "Jane", "last": "Smith"}
  924. ... ])
  925. >>>
  926. >>> # Single column transformation (operates on batches)
  927. >>> ds_incremented = ds.with_column("value_plus_one", add_one(col("value")))
  928. >>>
  929. >>> # Multi-column transformation (each column becomes a PyArrow Array)
  930. >>> ds_formatted = ds.with_column("full_name", format_name(col("first"), col("last")))
  931. >>>
  932. >>> # Callable class usage
  933. >>> add_five = AddOffset(5)
  934. >>> ds_with_offset = ds.with_column("value_plus_five", add_five(col("value")))
  935. >>>
  936. >>> # Can also be used in complex expressions
  937. >>> ds_complex = ds.with_column("doubled_plus_one", add_one(col("value")) * 2)
  938. """
  939. def decorator(
  940. func_or_class: Union[Callable[..., BatchColumn], Type[T]]
  941. ) -> Decorated:
  942. # Check if this is a callable class (has __call__ method defined)
  943. if isinstance(func_or_class, type) and issubclass(func_or_class, Callable):
  944. # Wrapper that delays instantiation and returns expressions instead of executing.
  945. # Without this, MyClass(args) would instantiate on the driver and
  946. # instance(col(...)) would try to execute rather than building an expression.
  947. class ExpressionAwareCallableClass:
  948. """Intercepts callable class instantiation to delay until actor execution.
  949. Allows natural syntax like:
  950. add_five = AddOffset(5)
  951. ds.with_column("result", add_five(col("x")))
  952. When instantiated, creates an _CallableClassUDF that is completely
  953. self-contained - it handles lazy instantiation and async bridging
  954. internally. From the planner's perspective, this is just a regular
  955. callable function.
  956. """
  957. def __init__(self, *args, **kwargs):
  958. # Create an _CallableClassUDF that is self-contained.
  959. # It lazily instantiates the class on first call (on the worker)
  960. # and handles async bridging internally.
  961. self._expr_udf = _CallableClassUDF(
  962. cls=func_or_class,
  963. ctor_args=args,
  964. ctor_kwargs=kwargs,
  965. return_dtype=return_dtype,
  966. )
  967. def __call__(self, *call_args, **call_kwargs):
  968. # Create UDFExpr with fn=_CallableClassUDF
  969. # The _CallableClassUDF is self-contained - no external setup needed
  970. return _create_udf_callable(
  971. self._expr_udf,
  972. return_dtype,
  973. )(*call_args, **call_kwargs)
  974. # Preserve the original class name and module for better error messages
  975. ExpressionAwareCallableClass.__name__ = func_or_class.__name__
  976. ExpressionAwareCallableClass.__qualname__ = func_or_class.__qualname__
  977. ExpressionAwareCallableClass.__module__ = func_or_class.__module__
  978. return ExpressionAwareCallableClass
  979. else:
  980. # Regular function
  981. return _create_udf_callable(func_or_class, return_dtype)
  982. return decorator
  983. def _create_pyarrow_wrapper(
  984. fn: Callable[..., BatchColumn]
  985. ) -> Callable[..., BatchColumn]:
  986. """Wrap a PyArrow compute function to auto-convert inputs to PyArrow format.
  987. This wrapper ensures that pandas Series and numpy arrays are converted to
  988. PyArrow Arrays before being passed to the function, enabling PyArrow compute
  989. functions to work seamlessly with any block format.
  990. Args:
  991. fn: The PyArrow compute function to wrap
  992. Returns:
  993. A wrapped function that handles format conversion
  994. """
  995. @functools.wraps(fn)
  996. def arrow_wrapper(*args, **kwargs):
  997. import numpy as np
  998. import pandas as pd
  999. import pyarrow as pa
  1000. def to_arrow(val):
  1001. """Convert a value to PyArrow Array if needed."""
  1002. if isinstance(val, (pa.Array, pa.ChunkedArray)):
  1003. return val, False
  1004. elif isinstance(val, pd.Series):
  1005. return pa.Array.from_pandas(val), True
  1006. elif isinstance(val, np.ndarray):
  1007. return pa.array(val), False
  1008. else:
  1009. return val, False
  1010. # Convert inputs to PyArrow and track pandas flags
  1011. args_results = [to_arrow(arg) for arg in args]
  1012. kwargs_results = {k: to_arrow(v) for k, v in kwargs.items()}
  1013. converted_args = [v[0] for v in args_results]
  1014. converted_kwargs = {k: v[0] for k, v in kwargs_results.items()}
  1015. input_was_pandas = any(v[1] for v in args_results) or any(
  1016. v[1] for v in kwargs_results.values()
  1017. )
  1018. # Call function with converted inputs
  1019. result = fn(*converted_args, **converted_kwargs)
  1020. # Convert result back to pandas if input was pandas
  1021. if input_was_pandas and isinstance(result, (pa.Array, pa.ChunkedArray)):
  1022. result = result.to_pandas()
  1023. return result
  1024. return arrow_wrapper
  1025. @PublicAPI(stability="alpha")
  1026. def pyarrow_udf(return_dtype: DataType) -> Callable[..., UDFExpr]:
  1027. """Decorator for PyArrow compute functions with automatic format conversion.
  1028. This decorator wraps PyArrow compute functions to automatically convert pandas
  1029. Series and numpy arrays to PyArrow Arrays, ensuring the function works seamlessly
  1030. regardless of the underlying block format (pandas, arrow, or items).
  1031. Used internally by namespace methods (list, str, struct) that wrap PyArrow
  1032. compute functions.
  1033. Args:
  1034. return_dtype: The data type of the return value
  1035. Returns:
  1036. A callable that creates UDFExpr instances with automatic conversion
  1037. """
  1038. def decorator(func: Callable[..., BatchColumn]) -> Callable[..., UDFExpr]:
  1039. # Wrap the function with PyArrow conversion logic
  1040. wrapped_fn = _create_pyarrow_wrapper(func)
  1041. # Create UDFExpr callable using the wrapped function
  1042. return _create_udf_callable(wrapped_fn, return_dtype)
  1043. return decorator
  1044. def _create_pyarrow_compute_udf(
  1045. pc_func: Callable[..., pyarrow.Array],
  1046. return_dtype: DataType | None = None,
  1047. ) -> Callable[..., "UDFExpr"]:
  1048. """Create an expression UDF backed by a PyArrow compute function."""
  1049. def wrapper(expr: "Expr", *positional: Any, **kwargs: Any) -> "UDFExpr":
  1050. @pyarrow_udf(return_dtype=return_dtype or expr.data_type)
  1051. def udf(arr: pyarrow.Array) -> pyarrow.Array:
  1052. return pc_func(arr, *positional, **kwargs)
  1053. return udf(expr)
  1054. return wrapper
  1055. @DeveloperAPI(stability="alpha")
  1056. @dataclass(frozen=True, eq=False, repr=False)
  1057. class DownloadExpr(Expr):
  1058. """Expression that represents a download operation."""
  1059. uri_column_name: str
  1060. filesystem: "pyarrow.fs.FileSystem" = None
  1061. data_type: DataType = field(default_factory=lambda: DataType.binary(), init=False)
  1062. def structurally_equals(self, other: Any) -> bool:
  1063. return (
  1064. isinstance(other, DownloadExpr)
  1065. and self.uri_column_name == other.uri_column_name
  1066. )
  1067. @DeveloperAPI(stability="alpha")
  1068. @dataclass(frozen=True, eq=False, repr=False)
  1069. class AliasExpr(Expr):
  1070. """Expression that represents an alias for an expression."""
  1071. expr: Expr
  1072. _name: str
  1073. _is_rename: bool
  1074. @property
  1075. def name(self) -> str:
  1076. """Get the alias name."""
  1077. return self._name
  1078. def alias(self, name: str) -> "Expr":
  1079. # Always unalias before creating new one
  1080. return AliasExpr(
  1081. self.expr.data_type, self.expr, _name=name, _is_rename=self._is_rename
  1082. )
  1083. def _unalias(self) -> "Expr":
  1084. return self.expr
  1085. def structurally_equals(self, other: Any) -> bool:
  1086. return (
  1087. isinstance(other, AliasExpr)
  1088. and self.expr.structurally_equals(other.expr)
  1089. and self.name == other.name
  1090. and self._is_rename == self._is_rename
  1091. )
  1092. @DeveloperAPI(stability="alpha")
  1093. @dataclass(frozen=True, eq=False, repr=False)
  1094. class StarExpr(Expr):
  1095. """Expression that represents all columns from the input.
  1096. This is a special expression used in projections to indicate that
  1097. all existing columns should be preserved at this position in the output.
  1098. It's typically used internally by operations like with_column() and
  1099. rename_columns() to maintain existing columns.
  1100. Example:
  1101. When with_column("new_col", expr) is called, it creates:
  1102. Project(exprs=[star(), expr.alias("new_col")])
  1103. This means: keep all existing columns, then add/overwrite "new_col"
  1104. """
  1105. # TODO: Add UnresolvedExpr. Both StarExpr and UnresolvedExpr won't have a defined data_type.
  1106. data_type: DataType = field(default_factory=lambda: DataType(object), init=False)
  1107. def structurally_equals(self, other: Any) -> bool:
  1108. return isinstance(other, StarExpr)
  1109. @PublicAPI(stability="beta")
  1110. def col(name: str) -> ColumnExpr:
  1111. """
  1112. Reference an existing column by name.
  1113. This is the primary way to reference columns in expressions.
  1114. The returned expression will extract values from the specified
  1115. column when evaluated.
  1116. Args:
  1117. name: The name of the column to reference
  1118. Returns:
  1119. A ColumnExpr that references the specified column
  1120. Example:
  1121. >>> from ray.data.expressions import col
  1122. >>> # Reference columns in an expression
  1123. >>> expr = col("price") * col("quantity")
  1124. >>>
  1125. >>> # Use with Dataset.with_column()
  1126. >>> import ray
  1127. >>> ds = ray.data.from_items([{"price": 10, "quantity": 2}])
  1128. >>> ds = ds.with_column("total", col("price") * col("quantity"))
  1129. """
  1130. return ColumnExpr(name)
  1131. @PublicAPI(stability="beta")
  1132. def lit(value: Any) -> LiteralExpr:
  1133. """
  1134. Create a literal expression from a constant value.
  1135. This creates an expression that represents a constant scalar value.
  1136. The value will be broadcast to all rows when the expression is evaluated.
  1137. Args:
  1138. value: The constant value to represent. Can be any Python object
  1139. (int, float, str, bool, etc.)
  1140. Returns:
  1141. A LiteralExpr containing the specified value
  1142. Example:
  1143. >>> from ray.data.expressions import col, lit
  1144. >>> # Create literals of different types
  1145. >>> five = lit(5)
  1146. >>> pi = lit(3.14159)
  1147. >>> name = lit("Alice")
  1148. >>> flag = lit(True)
  1149. >>>
  1150. >>> # Use in expressions
  1151. >>> expr = col("age") + lit(1) # Add 1 to age column
  1152. >>>
  1153. >>> # Use with Dataset.with_column()
  1154. >>> import ray
  1155. >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}])
  1156. >>> ds = ds.with_column("age_plus_one", col("age") + lit(1))
  1157. """
  1158. return LiteralExpr(value)
  1159. # TODO remove
  1160. @DeveloperAPI(stability="alpha")
  1161. def star() -> StarExpr:
  1162. """
  1163. References all input columns from the input.
  1164. This is a special expression used in projections to preserve all
  1165. existing columns. It's typically used with operations that want to
  1166. add or modify columns while keeping the rest.
  1167. Returns:
  1168. A StarExpr that represents all input columns.
  1169. """
  1170. return StarExpr()
  1171. @PublicAPI(stability="alpha")
  1172. def download(
  1173. uri_column_name: str,
  1174. *,
  1175. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  1176. ) -> DownloadExpr:
  1177. """
  1178. Create a download expression that downloads content from URIs.
  1179. This creates an expression that will download bytes from URIs stored in
  1180. a specified column. When evaluated, it will fetch the content from each URI
  1181. and return the downloaded bytes.
  1182. Args:
  1183. uri_column_name: The name of the column containing URIs to download from
  1184. filesystem: PyArrow filesystem to use for reading remote files.
  1185. If None, the filesystem is auto-detected from the path scheme.
  1186. Returns:
  1187. A DownloadExpr that will download content from the specified URI column
  1188. Example:
  1189. >>> from ray.data.expressions import download
  1190. >>> import ray
  1191. >>> # Create dataset with URIs
  1192. >>> ds = ray.data.from_items([
  1193. ... {"uri": "s3://bucket/file1.jpg", "id": "1"},
  1194. ... {"uri": "s3://bucket/file2.jpg", "id": "2"}
  1195. ... ])
  1196. >>> # Add downloaded bytes column
  1197. >>> ds_with_bytes = ds.with_column("bytes", download("uri"))
  1198. """
  1199. return DownloadExpr(uri_column_name=uri_column_name, filesystem=filesystem)
  1200. # ──────────────────────────────────────
  1201. # Public API for evaluation
  1202. # ──────────────────────────────────────
  1203. # Note: Implementation details are in _expression_evaluator.py
  1204. # Re-export eval_expr for public use
  1205. __all__ = [
  1206. "Operation",
  1207. "Expr",
  1208. "ColumnExpr",
  1209. "LiteralExpr",
  1210. "BinaryExpr",
  1211. "UnaryExpr",
  1212. "UDFExpr",
  1213. "DownloadExpr",
  1214. "AliasExpr",
  1215. "StarExpr",
  1216. "pyarrow_udf",
  1217. "udf",
  1218. "col",
  1219. "lit",
  1220. "download",
  1221. "star",
  1222. "_ArrayNamespace",
  1223. "_ListNamespace",
  1224. "_StringNamespace",
  1225. "_StructNamespace",
  1226. "_DatetimeNamespace",
  1227. ]
  1228. def __getattr__(name: str):
  1229. """Lazy import of namespace classes to avoid circular imports."""
  1230. if name == "_ArrayNamespace":
  1231. from ray.data.namespace_expressions.arr_namespace import _ArrayNamespace
  1232. return _ArrayNamespace
  1233. elif name == "_ListNamespace":
  1234. from ray.data.namespace_expressions.list_namespace import _ListNamespace
  1235. return _ListNamespace
  1236. elif name == "_StringNamespace":
  1237. from ray.data.namespace_expressions.string_namespace import _StringNamespace
  1238. return _StringNamespace
  1239. elif name == "_StructNamespace":
  1240. from ray.data.namespace_expressions.struct_namespace import _StructNamespace
  1241. return _StructNamespace
  1242. elif name == "_DatetimeNamespace":
  1243. from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
  1244. return _DatetimeNamespace
  1245. raise AttributeError(f"module {__name__!r} has no attribute {name!r}")