string_namespace.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. """String namespace for expression operations on string-typed columns."""
  2. from __future__ import annotations
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING, Any, Callable, Literal
  5. import pyarrow
  6. import pyarrow.compute as pc
  7. from ray.data.datatype import DataType
  8. from ray.data.expressions import _create_pyarrow_compute_udf, pyarrow_udf
  9. if TYPE_CHECKING:
  10. from ray.data.expressions import Expr, UDFExpr
  11. def _create_str_udf(
  12. pc_func: Callable[..., pyarrow.Array], return_dtype: DataType
  13. ) -> Callable[..., "UDFExpr"]:
  14. """Helper to create a string UDF that wraps a PyArrow compute function.
  15. This helper handles all types of PyArrow compute operations:
  16. - Unary operations (no args): upper(), lower(), reverse()
  17. - Pattern operations (pattern + args): starts_with(), contains()
  18. - Multi-argument operations: replace(), replace_slice()
  19. Args:
  20. pc_func: PyArrow compute function that takes (array, *positional, **kwargs)
  21. return_dtype: The return data type
  22. Returns:
  23. A callable that creates UDFExpr instances
  24. """
  25. return _create_pyarrow_compute_udf(pc_func, return_dtype=return_dtype)
  26. @dataclass
  27. class _StringNamespace:
  28. """Namespace for string operations on expression columns.
  29. This namespace provides methods for operating on string-typed columns using
  30. PyArrow compute functions.
  31. Example:
  32. >>> from ray.data.expressions import col
  33. >>> # Convert to uppercase
  34. >>> expr = col("name").str.upper()
  35. >>> # Get string length
  36. >>> expr = col("name").str.len()
  37. >>> # Check if string starts with a prefix
  38. >>> expr = col("name").str.starts_with("A")
  39. """
  40. _expr: Expr
  41. # Length methods
  42. def len(self) -> "UDFExpr":
  43. """Get the length of each string in characters."""
  44. return _create_str_udf(pc.utf8_length, DataType.int32())(self._expr)
  45. def byte_len(self) -> "UDFExpr":
  46. """Get the length of each string in bytes."""
  47. return _create_str_udf(pc.binary_length, DataType.int32())(self._expr)
  48. # Case methods
  49. def upper(self) -> "UDFExpr":
  50. """Convert strings to uppercase."""
  51. return _create_str_udf(pc.utf8_upper, DataType.string())(self._expr)
  52. def lower(self) -> "UDFExpr":
  53. """Convert strings to lowercase."""
  54. return _create_str_udf(pc.utf8_lower, DataType.string())(self._expr)
  55. def capitalize(self) -> "UDFExpr":
  56. """Capitalize the first character of each string."""
  57. return _create_str_udf(pc.utf8_capitalize, DataType.string())(self._expr)
  58. def title(self) -> "UDFExpr":
  59. """Convert strings to title case."""
  60. return _create_str_udf(pc.utf8_title, DataType.string())(self._expr)
  61. def swapcase(self) -> "UDFExpr":
  62. """Swap the case of each character."""
  63. return _create_str_udf(pc.utf8_swapcase, DataType.string())(self._expr)
  64. # Predicate methods
  65. def is_alpha(self) -> "UDFExpr":
  66. """Check if strings contain only alphabetic characters."""
  67. return _create_str_udf(pc.utf8_is_alpha, DataType.bool())(self._expr)
  68. def is_alnum(self) -> "UDFExpr":
  69. """Check if strings contain only alphanumeric characters."""
  70. return _create_str_udf(pc.utf8_is_alnum, DataType.bool())(self._expr)
  71. def is_digit(self) -> "UDFExpr":
  72. """Check if strings contain only digits."""
  73. return _create_str_udf(pc.utf8_is_digit, DataType.bool())(self._expr)
  74. def is_decimal(self) -> "UDFExpr":
  75. """Check if strings contain only decimal characters."""
  76. return _create_str_udf(pc.utf8_is_decimal, DataType.bool())(self._expr)
  77. def is_numeric(self) -> "UDFExpr":
  78. """Check if strings contain only numeric characters."""
  79. return _create_str_udf(pc.utf8_is_numeric, DataType.bool())(self._expr)
  80. def is_space(self) -> "UDFExpr":
  81. """Check if strings contain only whitespace."""
  82. return _create_str_udf(pc.utf8_is_space, DataType.bool())(self._expr)
  83. def is_lower(self) -> "UDFExpr":
  84. """Check if strings are lowercase."""
  85. return _create_str_udf(pc.utf8_is_lower, DataType.bool())(self._expr)
  86. def is_upper(self) -> "UDFExpr":
  87. """Check if strings are uppercase."""
  88. return _create_str_udf(pc.utf8_is_upper, DataType.bool())(self._expr)
  89. def is_title(self) -> "UDFExpr":
  90. """Check if strings are title-cased."""
  91. return _create_str_udf(pc.utf8_is_title, DataType.bool())(self._expr)
  92. def is_printable(self) -> "UDFExpr":
  93. """Check if strings contain only printable characters."""
  94. return _create_str_udf(pc.utf8_is_printable, DataType.bool())(self._expr)
  95. def is_ascii(self) -> "UDFExpr":
  96. """Check if strings contain only ASCII characters."""
  97. return _create_str_udf(pc.string_is_ascii, DataType.bool())(self._expr)
  98. # Searching methods
  99. def starts_with(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  100. """Check if strings start with a pattern."""
  101. return _create_str_udf(pc.starts_with, DataType.bool())(
  102. self._expr, pattern, *args, **kwargs
  103. )
  104. def ends_with(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  105. """Check if strings end with a pattern."""
  106. return _create_str_udf(pc.ends_with, DataType.bool())(
  107. self._expr, pattern, *args, **kwargs
  108. )
  109. def contains(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  110. """Check if strings contain a substring."""
  111. return _create_str_udf(pc.match_substring, DataType.bool())(
  112. self._expr, pattern, *args, **kwargs
  113. )
  114. def match(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  115. """Match strings against a SQL LIKE pattern."""
  116. return _create_str_udf(pc.match_like, DataType.bool())(
  117. self._expr, pattern, *args, **kwargs
  118. )
  119. def find(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  120. """Find the first occurrence of a substring."""
  121. return _create_str_udf(pc.find_substring, DataType.int32())(
  122. self._expr, pattern, *args, **kwargs
  123. )
  124. def count(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  125. """Count occurrences of a substring."""
  126. return _create_str_udf(pc.count_substring, DataType.int32())(
  127. self._expr, pattern, *args, **kwargs
  128. )
  129. def find_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  130. """Find the first occurrence matching a regex pattern."""
  131. return _create_str_udf(pc.find_substring_regex, DataType.int32())(
  132. self._expr, pattern, *args, **kwargs
  133. )
  134. def count_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  135. """Count occurrences matching a regex pattern."""
  136. return _create_str_udf(pc.count_substring_regex, DataType.int32())(
  137. self._expr, pattern, *args, **kwargs
  138. )
  139. def match_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  140. """Check if strings match a regex pattern."""
  141. return _create_str_udf(pc.match_substring_regex, DataType.bool())(
  142. self._expr, pattern, *args, **kwargs
  143. )
  144. # Transformation methods
  145. def reverse(self) -> "UDFExpr":
  146. """Reverse each string."""
  147. return _create_str_udf(pc.utf8_reverse, DataType.string())(self._expr)
  148. def slice(self, *args: Any, **kwargs: Any) -> "UDFExpr":
  149. """Slice strings by codeunit indices."""
  150. return _create_str_udf(pc.utf8_slice_codeunits, DataType.string())(
  151. self._expr, *args, **kwargs
  152. )
  153. def replace(
  154. self, pattern: str, replacement: str, *args: Any, **kwargs: Any
  155. ) -> "UDFExpr":
  156. """Replace occurrences of a substring."""
  157. return _create_str_udf(pc.replace_substring, DataType.string())(
  158. self._expr, pattern, replacement, *args, **kwargs
  159. )
  160. def replace_regex(
  161. self, pattern: str, replacement: str, *args: Any, **kwargs: Any
  162. ) -> "UDFExpr":
  163. """Replace occurrences matching a regex pattern."""
  164. return _create_str_udf(pc.replace_substring_regex, DataType.string())(
  165. self._expr, pattern, replacement, *args, **kwargs
  166. )
  167. def replace_slice(
  168. self, start: int, stop: int, replacement: str, *args: Any, **kwargs: Any
  169. ) -> "UDFExpr":
  170. """Replace a slice with a string."""
  171. return _create_str_udf(pc.binary_replace_slice, DataType.string())(
  172. self._expr, start, stop, replacement, *args, **kwargs
  173. )
  174. def split(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  175. """Split strings by a pattern."""
  176. return _create_str_udf(pc.split_pattern, DataType(object))(
  177. self._expr, pattern, *args, **kwargs
  178. )
  179. def split_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  180. """Split strings by a regex pattern."""
  181. return _create_str_udf(pc.split_pattern_regex, DataType(object))(
  182. self._expr, pattern, *args, **kwargs
  183. )
  184. def split_whitespace(self, *args: Any, **kwargs: Any) -> "UDFExpr":
  185. """Split strings on whitespace."""
  186. return _create_str_udf(pc.utf8_split_whitespace, DataType(object))(
  187. self._expr, *args, **kwargs
  188. )
  189. def extract(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
  190. """Extract a substring matching a regex pattern."""
  191. return _create_str_udf(pc.extract_regex, DataType.string())(
  192. self._expr, pattern, *args, **kwargs
  193. )
  194. def repeat(self, n: int, *args: Any, **kwargs: Any) -> "UDFExpr":
  195. """Repeat each string n times."""
  196. return _create_str_udf(pc.binary_repeat, DataType.string())(
  197. self._expr, n, *args, **kwargs
  198. )
  199. def center(
  200. self, width: int, padding: str = " ", *args: Any, **kwargs: Any
  201. ) -> "UDFExpr":
  202. """Center strings in a field of given width."""
  203. return _create_str_udf(pc.utf8_center, DataType.string())(
  204. self._expr, width, padding, *args, **kwargs
  205. )
  206. def lpad(
  207. self, width: int, padding: str = " ", *args: Any, **kwargs: Any
  208. ) -> "UDFExpr":
  209. """Right-align strings by padding with a given character while respecting ``width``.
  210. If the string is longer than the specified width, it remains intact (no truncation occurs).
  211. """
  212. return _create_str_udf(pc.utf8_lpad, DataType.string())(
  213. self._expr, width, padding, *args, **kwargs
  214. )
  215. def rpad(
  216. self, width: int, padding: str = " ", *args: Any, **kwargs: Any
  217. ) -> "UDFExpr":
  218. """Left-align strings by padding with a given character while respecting ``width``.
  219. If the string is longer than the specified width, it remains intact (no truncation occurs).
  220. """
  221. return _create_str_udf(pc.utf8_rpad, DataType.string())(
  222. self._expr, width, padding, *args, **kwargs
  223. )
  224. # Custom methods that need special logic beyond simple PyArrow function calls
  225. def strip(self, characters: str | None = None) -> "UDFExpr":
  226. """Remove leading and trailing whitespace or specified characters.
  227. Args:
  228. characters: Characters to remove. If None, removes whitespace.
  229. Returns:
  230. UDFExpr that strips characters from both ends.
  231. """
  232. @pyarrow_udf(return_dtype=DataType.string())
  233. def _str_strip(arr: pyarrow.Array) -> pyarrow.Array:
  234. if characters is None:
  235. return pc.utf8_trim_whitespace(arr)
  236. else:
  237. return pc.utf8_trim(arr, characters=characters)
  238. return _str_strip(self._expr)
  239. def lstrip(self, characters: str | None = None) -> "UDFExpr":
  240. """Remove leading whitespace or specified characters.
  241. Args:
  242. characters: Characters to remove. If None, removes whitespace.
  243. Returns:
  244. UDFExpr that strips characters from the left.
  245. """
  246. @pyarrow_udf(return_dtype=DataType.string())
  247. def _str_lstrip(arr: pyarrow.Array) -> pyarrow.Array:
  248. if characters is None:
  249. return pc.utf8_ltrim_whitespace(arr)
  250. else:
  251. return pc.utf8_ltrim(arr, characters=characters)
  252. return _str_lstrip(self._expr)
  253. def rstrip(self, characters: str | None = None) -> "UDFExpr":
  254. """Remove trailing whitespace or specified characters.
  255. Args:
  256. characters: Characters to remove. If None, removes whitespace.
  257. Returns:
  258. UDFExpr that strips characters from the right.
  259. """
  260. @pyarrow_udf(return_dtype=DataType.string())
  261. def _str_rstrip(arr: pyarrow.Array) -> pyarrow.Array:
  262. if characters is None:
  263. return pc.utf8_rtrim_whitespace(arr)
  264. else:
  265. return pc.utf8_rtrim(arr, characters=characters)
  266. return _str_rstrip(self._expr)
  267. # Padding
  268. def pad(
  269. self,
  270. width: int,
  271. fillchar: str = " ",
  272. side: Literal["left", "right", "both"] = "right",
  273. ) -> "UDFExpr":
  274. """Pad strings to a specified width.
  275. Args:
  276. width: Target width.
  277. fillchar: Character to use for padding.
  278. side: "left", "right", or "both" for padding side.
  279. Returns:
  280. UDFExpr that pads strings.
  281. """
  282. @pyarrow_udf(return_dtype=DataType.string())
  283. def _str_pad(arr: pyarrow.Array) -> pyarrow.Array:
  284. if side == "right":
  285. return pc.utf8_rpad(arr, width=width, padding=fillchar)
  286. elif side == "left":
  287. return pc.utf8_lpad(arr, width=width, padding=fillchar)
  288. elif side == "both":
  289. return pc.utf8_center(arr, width=width, padding=fillchar)
  290. else:
  291. raise ValueError("side must be 'left', 'right', or 'both'")
  292. return _str_pad(self._expr)