expression_utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. """Utility functions for expression-based operations."""
  2. from typing import TYPE_CHECKING, Any, Callable, List, Optional
  3. if TYPE_CHECKING:
  4. from ray.data.expressions import Expr
  5. def create_callable_class_udf_init_fn(
  6. exprs: List["Expr"],
  7. ) -> Optional[Callable[[], None]]:
  8. """Create an init_fn to initialize all callable class UDFs in expressions.
  9. This function collects all _CallableClassUDF instances from the given expressions,
  10. groups them by their callable_class_spec key, and returns an init_fn that
  11. initializes each group at actor startup. UDFs with the same key (same class and
  12. constructor args) share a single instance to ensure all are properly initialized.
  13. Args:
  14. exprs: List of expressions to collect callable class UDFs from.
  15. Returns:
  16. An init_fn that initializes all callable class UDFs, or None if there are
  17. no callable class UDFs in the expressions.
  18. """
  19. from ray.data._internal.planner.plan_expression.expression_visitors import (
  20. _CallableClassUDFCollector,
  21. )
  22. callable_class_udfs = []
  23. for expr in exprs:
  24. collector = _CallableClassUDFCollector()
  25. collector.visit(expr)
  26. callable_class_udfs.extend(collector.get_callable_class_udfs())
  27. if not callable_class_udfs:
  28. return None
  29. # Group UDFs by callable_class_spec key.
  30. # Multiple _CallableClassUDF objects may have the same key (same class + args).
  31. # We need to initialize ALL of them, sharing a single instance per key.
  32. udfs_by_key = {}
  33. for udf in callable_class_udfs:
  34. key = udf.callable_class_spec.make_key()
  35. if key not in udfs_by_key:
  36. udfs_by_key[key] = []
  37. udfs_by_key[key].append(udf)
  38. def init_fn():
  39. for udfs_with_same_key in udfs_by_key.values():
  40. # Initialize the first UDF to create the instance
  41. first_udf = udfs_with_same_key[0]
  42. first_udf.init()
  43. # Share the instance with all other UDFs that have the same key
  44. for other_udf in udfs_with_same_key[1:]:
  45. other_udf._instance = first_udf._instance
  46. return init_fn
  47. def _call_udf_instance_with_async_bridge(
  48. instance: Any,
  49. *args,
  50. **kwargs,
  51. ) -> Any:
  52. """Call a UDF instance, bridging from sync context to async if needed.
  53. This handles the complexity of calling callable class UDF instances that may
  54. be sync, async coroutine, or async generator functions.
  55. Args:
  56. instance: The callable instance to call
  57. *args: Positional arguments
  58. **kwargs: Keyword arguments
  59. Returns:
  60. The result of calling the instance
  61. """
  62. import asyncio
  63. import inspect
  64. # Check if the instance's __call__ is async
  65. if inspect.iscoroutinefunction(instance.__call__):
  66. # Async coroutine: bridge from sync to async
  67. return asyncio.run(instance(*args, **kwargs))
  68. elif inspect.isasyncgenfunction(instance.__call__):
  69. # Async generator: collect results
  70. async def _collect():
  71. results = []
  72. async for item in instance(*args, **kwargs):
  73. results.append(item)
  74. # In expressions, the UDF must return a single array with the same
  75. # length as the input (one output element per input row).
  76. # If the async generator yields multiple arrays, we take the last one
  77. # since expressions don't support multi-batch output semantics.
  78. if not results:
  79. return None
  80. elif len(results) == 1:
  81. return results[0]
  82. else:
  83. import logging
  84. logging.warning(
  85. f"Async generator yielded {len(results)} values in expression context; "
  86. "only the last (most recent) is returned. Use map_batches for multi-yield support."
  87. )
  88. return results[-1]
  89. return asyncio.run(_collect())
  90. else:
  91. # Synchronous instance - direct call
  92. return instance(*args, **kwargs)