util.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794
  1. import functools
  2. import importlib
  3. import logging
  4. import os
  5. import pathlib
  6. import platform
  7. import random
  8. import sys
  9. import threading
  10. import time
  11. import urllib.parse
  12. import uuid
  13. from queue import Empty, Full, Queue
  14. from types import ModuleType
  15. from typing import (
  16. TYPE_CHECKING,
  17. Any,
  18. Callable,
  19. Dict,
  20. Generator,
  21. Iterable,
  22. Iterator,
  23. List,
  24. Optional,
  25. Tuple,
  26. TypeVar,
  27. Union,
  28. overload,
  29. )
  30. import numpy as np
  31. import pandas as pd
  32. # NOTE: pyarrow.fs module needs to be explicitly imported!
  33. import pyarrow
  34. import pyarrow.fs
  35. import ray
  36. from ray._common.retry import call_with_retry
  37. from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext
  38. from ray.util.annotations import DeveloperAPI
  39. import psutil
  40. # TypeVar for preserving function/class signatures through decorators
  41. F = TypeVar("F", bound=Callable[..., Any])
  42. if TYPE_CHECKING:
  43. import pandas
  44. from ray.data._internal.compute import ComputeStrategy
  45. from ray.data._internal.execution.interfaces import RefBundle
  46. from ray.data._internal.planner.exchange.sort_task_spec import SortKey
  47. from ray.data.block import (
  48. Block,
  49. BlockMetadataWithSchema,
  50. Schema,
  51. UserDefinedFunction,
  52. )
  53. from ray.data.datasource import Datasource, Reader
  54. from ray.util.placement_group import PlacementGroup
  55. logger = logging.getLogger(__name__)
  56. KiB = 1024 # bytes
  57. MiB = 1024 * KiB
  58. GiB = 1024 * MiB
  59. SENTINEL = object()
  60. _LOCAL_SCHEME = "local"
  61. _EXAMPLE_SCHEME = "example"
  62. LazyModule = Union[None, bool, ModuleType]
  63. _pyarrow_dataset: LazyModule = None
  64. class _OrderedNullSentinel:
  65. """Sentinel value that sorts greater than any other non-null value.
  66. NOTE: Semantic of this sentinel is closely mirroring that one of
  67. ``np.nan`` for the purpose of consistency in handling of
  68. ``None``s and ``np.nan``s.
  69. """
  70. def __eq__(self, other):
  71. return False
  72. def __lt__(self, other):
  73. # not None < _OrderedNullSentinel
  74. # _OrderedNullSentinel < _OrderedNullSentinel
  75. # _OrderedNullSentinel < None
  76. # _OrderedNullSentinel < np.nan
  77. return isinstance(other, _OrderedNullSentinel) or is_null(other)
  78. def __le__(self, other):
  79. # NOTE: This is just a shortened version of
  80. # self < other or self == other
  81. return self.__lt__(other)
  82. def __gt__(self, other):
  83. return not self.__le__(other)
  84. def __ge__(self, other):
  85. return not self.__lt__(other)
  86. def __hash__(self):
  87. return id(self)
  88. NULL_SENTINEL = _OrderedNullSentinel()
  89. def _lazy_import_pyarrow_dataset() -> LazyModule:
  90. global _pyarrow_dataset
  91. if _pyarrow_dataset is None:
  92. try:
  93. from pyarrow import dataset as _pyarrow_dataset
  94. except ModuleNotFoundError:
  95. # If module is not found, set _pyarrow to False so we won't
  96. # keep trying to import it on every _lazy_import_pyarrow() call.
  97. _pyarrow_dataset = False
  98. return _pyarrow_dataset
  99. def _check_pyarrow_version():
  100. ray.data._internal.utils.arrow_utils._check_pyarrow_version()
  101. def _autodetect_parallelism(
  102. parallelism: int,
  103. target_max_block_size: Optional[int],
  104. ctx: DataContext,
  105. datasource_or_legacy_reader: Optional[Union["Datasource", "Reader"]] = None,
  106. mem_size: Optional[int] = None,
  107. placement_group: Optional["PlacementGroup"] = None,
  108. avail_cpus: Optional[int] = None,
  109. ) -> Tuple[int, str, Optional[int]]:
  110. """Returns parallelism to use and the min safe parallelism to avoid OOMs.
  111. This detects parallelism using the following heuristics, applied in order:
  112. 1) We start with the default value of 200. This can be overridden by
  113. setting the `read_op_min_num_blocks` attribute of
  114. :class:`~ray.data.context.DataContext`.
  115. 2) Min block size. If the parallelism would make blocks smaller than this
  116. threshold, the parallelism is reduced to avoid the overhead of tiny blocks.
  117. 3) Max block size. If the parallelism would make blocks larger than this
  118. threshold, the parallelism is increased to avoid OOMs during processing.
  119. 4) Available CPUs. If the parallelism cannot make use of all the available
  120. CPUs in the cluster, the parallelism is increased until it can.
  121. Args:
  122. parallelism: The user-requested parallelism, or -1 for auto-detection.
  123. target_max_block_size: The target max block size to
  124. produce. We pass this separately from the
  125. DatasetContext because it may be set per-op instead of
  126. per-Dataset.
  127. ctx: The current Dataset context to use for configs.
  128. datasource_or_legacy_reader: The datasource or legacy reader, to be used for
  129. data size estimation.
  130. mem_size: If passed, then used to compute the parallelism according to
  131. target_max_block_size.
  132. placement_group: The placement group that this Dataset
  133. will execute inside, if any.
  134. avail_cpus: Override avail cpus detection (for testing only).
  135. Returns:
  136. Tuple of detected parallelism (only if -1 was specified), the reason
  137. for the detected parallelism (only if -1 was specified), and the estimated
  138. inmemory size of the dataset.
  139. """
  140. min_safe_parallelism = 1
  141. max_reasonable_parallelism = sys.maxsize
  142. if mem_size is None and datasource_or_legacy_reader:
  143. mem_size = datasource_or_legacy_reader.estimate_inmemory_data_size()
  144. if (
  145. mem_size is not None
  146. and not np.isnan(mem_size)
  147. and target_max_block_size is not None
  148. ):
  149. min_safe_parallelism = max(1, int(mem_size / target_max_block_size))
  150. max_reasonable_parallelism = max(1, int(mem_size / ctx.target_min_block_size))
  151. reason = ""
  152. if parallelism < 0:
  153. if parallelism != -1:
  154. raise ValueError("`parallelism` must either be -1 or a positive integer.")
  155. if (
  156. ctx.min_parallelism is not None
  157. and ctx.min_parallelism != DEFAULT_READ_OP_MIN_NUM_BLOCKS
  158. and ctx.read_op_min_num_blocks == DEFAULT_READ_OP_MIN_NUM_BLOCKS
  159. ):
  160. logger.warning(
  161. "``DataContext.min_parallelism`` is deprecated in Ray 2.10. "
  162. "Please specify ``DataContext.read_op_min_num_blocks`` instead."
  163. )
  164. ctx.read_op_min_num_blocks = ctx.min_parallelism
  165. # Start with 2x the number of cores as a baseline, with a min floor.
  166. if placement_group is None:
  167. placement_group = ray.util.get_current_placement_group()
  168. avail_cpus = avail_cpus or _estimate_avail_cpus(placement_group)
  169. parallelism = max(
  170. min(ctx.read_op_min_num_blocks, max_reasonable_parallelism),
  171. min_safe_parallelism,
  172. avail_cpus * 2,
  173. )
  174. if parallelism == ctx.read_op_min_num_blocks:
  175. reason = (
  176. "DataContext.get_current().read_op_min_num_blocks="
  177. f"{ctx.read_op_min_num_blocks}"
  178. )
  179. elif parallelism == max_reasonable_parallelism:
  180. reason = (
  181. "output blocks of size at least "
  182. "DataContext.get_current().target_min_block_size="
  183. f"{ctx.target_min_block_size / MiB} MiB"
  184. )
  185. elif parallelism == min_safe_parallelism:
  186. # Handle ``None`` (unlimited) gracefully in the log message.
  187. if ctx.target_max_block_size is None:
  188. display_val = "unlimited"
  189. else:
  190. display_val = f"{ctx.target_max_block_size / MiB} MiB"
  191. reason = (
  192. "output blocks of size at most "
  193. "DataContext.get_current().target_max_block_size="
  194. f"{display_val}"
  195. )
  196. else:
  197. reason = (
  198. "parallelism at least twice the available number "
  199. f"of CPUs ({avail_cpus})"
  200. )
  201. logger.debug(
  202. f"Autodetected parallelism={parallelism} based on "
  203. f"estimated_available_cpus={avail_cpus} and "
  204. f"estimated_data_size={mem_size}."
  205. )
  206. return parallelism, reason, mem_size
  207. def _estimate_avail_cpus(cur_pg: Optional["PlacementGroup"]) -> int:
  208. """Estimates the available CPU parallelism for this Dataset in the cluster.
  209. If we aren't in a placement group, this is trivially the number of CPUs in the
  210. cluster. Otherwise, we try to calculate how large the placement group is relative
  211. to the size of the cluster.
  212. Args:
  213. cur_pg: The current placement group, if any.
  214. """
  215. cluster_cpus = int(ray.cluster_resources().get("CPU", 1))
  216. cluster_gpus = int(ray.cluster_resources().get("GPU", 0))
  217. # If we're in a placement group, we shouldn't assume the entire cluster's
  218. # resources are available for us to use. Estimate an upper bound on what's
  219. # reasonable to assume is available for datasets to use.
  220. if cur_pg:
  221. pg_cpus = 0
  222. for bundle in cur_pg.bundle_specs:
  223. # Calculate the proportion of the cluster this placement group "takes up".
  224. # Then scale our cluster_cpus proportionally to avoid over-parallelizing
  225. # if there are many parallel Tune trials using the cluster.
  226. cpu_fraction = bundle.get("CPU", 0) / max(1, cluster_cpus)
  227. gpu_fraction = bundle.get("GPU", 0) / max(1, cluster_gpus)
  228. max_fraction = max(cpu_fraction, gpu_fraction)
  229. # Over-parallelize by up to a factor of 2, but no more than that. It's
  230. # preferrable to over-estimate than under-estimate.
  231. pg_cpus += 2 * int(max_fraction * cluster_cpus)
  232. return min(cluster_cpus, pg_cpus)
  233. return cluster_cpus
  234. def _estimate_available_parallelism() -> int:
  235. """Estimates the available CPU parallelism for this Dataset in the cluster.
  236. If we are currently in a placement group, take that into account."""
  237. cur_pg = ray.util.get_current_placement_group()
  238. return _estimate_avail_cpus(cur_pg)
  239. def _warn_on_high_parallelism(requested_parallelism, num_read_tasks):
  240. available_cpu_slots = ray.available_resources().get("CPU", 1)
  241. if (
  242. requested_parallelism
  243. and num_read_tasks > available_cpu_slots * 4
  244. and num_read_tasks >= 5000
  245. ):
  246. logger.warning(
  247. f"{WARN_PREFIX} The requested parallelism of {requested_parallelism} "
  248. "is more than 4x the number of available CPU slots in the cluster of "
  249. f"{available_cpu_slots}. This can "
  250. "lead to slowdowns during the data reading phase due to excessive "
  251. "task creation. Reduce the parallelism to match with the available "
  252. "CPU slots in the cluster, or set parallelism to -1 for Ray Data "
  253. "to automatically determine the parallelism. "
  254. "You can ignore this message if the cluster is expected to autoscale."
  255. )
  256. def _check_import(obj, *, module: str, package: str) -> None:
  257. """Check if a required dependency is installed.
  258. If `module` can't be imported, this function raises an `ImportError` instructing
  259. the user to install `package` from PyPI.
  260. Args:
  261. obj: The object that has a dependency.
  262. module: The name of the module to import.
  263. package: The name of the package on PyPI.
  264. """
  265. try:
  266. importlib.import_module(module)
  267. except ImportError:
  268. raise ImportError(
  269. f"`{obj.__class__.__name__}` depends on '{module}', but Ray Data couldn't "
  270. f"import it. Install '{module}' by running `pip install {package}`."
  271. )
  272. def _resolve_custom_scheme(path: str) -> str:
  273. """Returns the resolved path if the given path follows a Ray-specific custom
  274. scheme. Othewise, returns the path unchanged.
  275. The supported custom schemes are: "local", "example".
  276. """
  277. parsed_uri = urllib.parse.urlparse(path)
  278. if parsed_uri.scheme == _LOCAL_SCHEME:
  279. path = parsed_uri.netloc + parsed_uri.path
  280. elif parsed_uri.scheme == _EXAMPLE_SCHEME:
  281. example_data_path = pathlib.Path(__file__).parent.parent / "examples" / "data"
  282. path = example_data_path / (parsed_uri.netloc + parsed_uri.path)
  283. path = str(path.resolve())
  284. return path
  285. def _is_local_scheme(paths: Union[str, List[str]]) -> bool:
  286. """Returns True if the given paths are in local scheme.
  287. Note: The paths must be in same scheme, i.e. it's invalid and
  288. will raise error if paths are mixed with different schemes.
  289. """
  290. if isinstance(paths, str):
  291. paths = [paths]
  292. if isinstance(paths, pathlib.Path):
  293. paths = [str(paths)]
  294. elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths):
  295. raise ValueError("paths must be a path string or a list of path strings.")
  296. elif len(paths) == 0:
  297. raise ValueError("Must provide at least one path.")
  298. num = sum(urllib.parse.urlparse(path).scheme == _LOCAL_SCHEME for path in paths)
  299. if num > 0 and num < len(paths):
  300. raise ValueError(
  301. "The paths must all be local-scheme or not local-scheme, "
  302. f"but found mixed {paths}"
  303. )
  304. return num == len(paths)
  305. def _truncated_repr(obj: Any) -> str:
  306. """Utility to return a truncated object representation for error messages."""
  307. msg = str(obj)
  308. if len(msg) > 200:
  309. msg = msg[:200] + "..."
  310. return msg
  311. def _insert_doc_at_pattern(
  312. obj,
  313. *,
  314. message: str,
  315. pattern: str,
  316. insert_after: bool = True,
  317. directive: Optional[str] = None,
  318. skip_matches: int = 0,
  319. ) -> str:
  320. if "\n" in message:
  321. raise ValueError(
  322. "message shouldn't contain any newlines, since this function will insert "
  323. f"its own linebreaks when text wrapping: {message}"
  324. )
  325. doc = obj.__doc__.strip()
  326. if not doc:
  327. doc = ""
  328. if pattern == "" and insert_after:
  329. # Empty pattern + insert_after means that we want to append the message to the
  330. # end of the docstring.
  331. head = doc
  332. tail = ""
  333. else:
  334. tail = doc
  335. i = tail.find(pattern)
  336. skip_matches_left = skip_matches
  337. while i != -1:
  338. if insert_after:
  339. # Set offset to the first character after the pattern.
  340. offset = i + len(pattern)
  341. else:
  342. # Set offset to the first character in the matched line.
  343. offset = tail[:i].rfind("\n") + 1
  344. head = tail[:offset]
  345. tail = tail[offset:]
  346. skip_matches_left -= 1
  347. if skip_matches_left <= 0:
  348. break
  349. elif not insert_after:
  350. # Move past the found pattern, since we're skipping it.
  351. tail = tail[i - offset + len(pattern) :]
  352. i = tail.find(pattern)
  353. else:
  354. raise ValueError(
  355. f"Pattern {pattern} not found after {skip_matches} skips in docstring "
  356. f"{doc}"
  357. )
  358. # Get indentation of the to-be-inserted text.
  359. after_lines = list(filter(bool, tail.splitlines()))
  360. if len(after_lines) > 0:
  361. lines = after_lines
  362. else:
  363. lines = list(filter(bool, reversed(head.splitlines())))
  364. # Should always have at least one non-empty line in the docstring.
  365. assert len(lines) > 0
  366. indent = " " * (len(lines[0]) - len(lines[0].lstrip()))
  367. # Handle directive.
  368. message = message.strip("\n")
  369. if directive is not None:
  370. base = f"{indent}.. {directive}::\n"
  371. message = message.replace("\n", "\n" + indent + " " * 4)
  372. message = base + indent + " " * 4 + message
  373. else:
  374. message = indent + message.replace("\n", "\n" + indent)
  375. # Add two blank lines before/after message, if necessary.
  376. if insert_after ^ (pattern == "\n\n"):
  377. # Only two blank lines before message if:
  378. # 1. Inserting message after pattern and pattern is not two blank lines.
  379. # 2. Inserting message before pattern and pattern is two blank lines.
  380. message = "\n\n" + message
  381. if (not insert_after) ^ (pattern == "\n\n"):
  382. # Only two blank lines after message if:
  383. # 1. Inserting message before pattern and pattern is not two blank lines.
  384. # 2. Inserting message after pattern and pattern is two blank lines.
  385. message = message + "\n\n"
  386. # Insert message before/after pattern.
  387. parts = [head, message, tail]
  388. # Build new docstring.
  389. obj.__doc__ = "".join(parts)
  390. def _consumption_api(
  391. if_more_than_read: bool = False,
  392. datasource_metadata: Optional[str] = None,
  393. extra_condition: Optional[str] = None,
  394. delegate: Optional[str] = None,
  395. pattern: str = "Examples:",
  396. insert_after: bool = False,
  397. ) -> Callable[[F], F]:
  398. """Annotate the function with an indication that it's a consumption API, and that it
  399. will trigger Dataset execution.
  400. """
  401. base = (
  402. " will trigger execution of the lazy transformations performed on "
  403. "this dataset."
  404. )
  405. if delegate:
  406. message = delegate + base
  407. elif not if_more_than_read:
  408. message = "This operation" + base
  409. else:
  410. condition = "If this dataset consists of more than a read, "
  411. if datasource_metadata is not None:
  412. condition += (
  413. f"or if the {datasource_metadata} can't be determined from the "
  414. "metadata provided by the datasource, "
  415. )
  416. if extra_condition is not None:
  417. condition += extra_condition + ", "
  418. message = condition + "then this operation" + base
  419. def wrap(obj: F) -> F:
  420. _insert_doc_at_pattern(
  421. obj,
  422. message=message,
  423. pattern=pattern,
  424. insert_after=insert_after,
  425. directive="note",
  426. )
  427. return obj
  428. return wrap
  429. @overload
  430. def ConsumptionAPI(obj: F) -> F:
  431. ...
  432. @overload
  433. def ConsumptionAPI(
  434. *,
  435. if_more_than_read: bool = False,
  436. datasource_metadata: Optional[str] = None,
  437. extra_condition: Optional[str] = None,
  438. delegate: Optional[str] = None,
  439. ) -> Callable[[F], F]:
  440. ...
  441. def ConsumptionAPI(*args, **kwargs):
  442. """Annotate the function with an indication that it's a consumption API, and that it
  443. will trigger Dataset execution.
  444. """
  445. if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
  446. return _consumption_api()(args[0])
  447. return _consumption_api(*args, **kwargs)
  448. def _all_to_all_api() -> Callable[[F], F]:
  449. """Annotate the function with an indication that it's a all to all API, and that it
  450. is an operation that requires all inputs to be materialized in-memory to execute.
  451. """
  452. def wrap(obj: F) -> F:
  453. _insert_doc_at_pattern(
  454. obj,
  455. message=(
  456. "This operation requires all inputs to be "
  457. "materialized in object store for it to execute."
  458. ),
  459. pattern="Examples:",
  460. insert_after=False,
  461. directive="note",
  462. )
  463. return obj
  464. return wrap
  465. @overload
  466. def AllToAllAPI(obj: F) -> F:
  467. ...
  468. def AllToAllAPI(*args, **kwargs):
  469. """Annotate the function with an indication that it's a all to all API, and that it
  470. is an operation that requires all inputs to be materialized in-memory to execute.
  471. """
  472. # This should only be used as a decorator for dataset methods.
  473. assert len(args) == 1 and len(kwargs) == 0 and callable(args[0])
  474. return _all_to_all_api()(args[0])
  475. def get_compute_strategy(
  476. fn: "UserDefinedFunction",
  477. fn_constructor_args: Optional[Iterable[Any]] = None,
  478. compute: Optional[Union[str, "ComputeStrategy"]] = None,
  479. concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
  480. ) -> "ComputeStrategy":
  481. """Get `ComputeStrategy` based on the function or class, and concurrency
  482. information.
  483. Args:
  484. fn: The function or generator to apply to a record batch, or a class type
  485. that can be instantiated to create such a callable.
  486. fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
  487. compute: Either "tasks" (default) to use Ray Tasks or an
  488. :class:`~ray.data.ActorPoolStrategy` to use an autoscaling actor pool.
  489. concurrency: The number of Ray workers to use concurrently.
  490. Returns:
  491. The `ComputeStrategy` for execution.
  492. """
  493. # Lazily import these objects to avoid circular imports.
  494. from ray.data._internal.compute import ActorPoolStrategy, TaskPoolStrategy
  495. from ray.data.block import CallableClass
  496. if isinstance(fn, CallableClass):
  497. is_callable_class = True
  498. else:
  499. # TODO(chengsu): disallow object that is not a function. For example,
  500. # An object instance of class often indicates a bug in user code.
  501. is_callable_class = False
  502. if fn_constructor_args is not None:
  503. raise ValueError(
  504. "``fn_constructor_args`` can only be specified if providing a "
  505. f"callable class instance for ``fn``, but got: {fn}."
  506. )
  507. if compute is not None:
  508. if is_callable_class and (
  509. compute == "tasks" or isinstance(compute, TaskPoolStrategy)
  510. ):
  511. raise ValueError(
  512. f"You specified the callable class {fn} as your UDF with the compute "
  513. f"{compute}, but Ray Data can't schedule callable classes with the task "
  514. f"pool strategy. To fix this error, pass an ActorPoolStrategy to compute or "
  515. f"None to use the default compute strategy."
  516. )
  517. elif not is_callable_class and (
  518. compute == "actors" or isinstance(compute, ActorPoolStrategy)
  519. ):
  520. raise ValueError(
  521. f"You specified the function {fn} as your UDF with the compute "
  522. f"{compute}, but Ray Data can't schedule regular functions with the actor "
  523. f"pool strategy. To fix this error, pass a TaskPoolStrategy to compute or "
  524. f"None to use the default compute strategy."
  525. )
  526. return compute
  527. elif concurrency is not None:
  528. # Legacy code path to support `concurrency` argument.
  529. logger.warning(
  530. "The argument ``concurrency`` is deprecated in Ray 2.51. Please specify "
  531. "argument ``compute`` instead. For more information, see "
  532. "https://docs.ray.io/en/master/data/transforming-data.html#"
  533. "stateful-transforms."
  534. )
  535. if isinstance(concurrency, tuple):
  536. # Validate tuple length and that all elements are integers
  537. if len(concurrency) not in (2, 3) or not all(
  538. isinstance(c, int) for c in concurrency
  539. ):
  540. raise ValueError(
  541. "``concurrency`` is expected to be set as a tuple of "
  542. f"integers, but got: {concurrency}."
  543. )
  544. # Check if function is callable class (common validation)
  545. if not is_callable_class:
  546. raise ValueError(
  547. "``concurrency`` is set as a tuple of integers, but ``fn`` "
  548. f"is not a callable class: {fn}. Use ``concurrency=n`` to "
  549. "control maximum number of workers to use."
  550. )
  551. # Create ActorPoolStrategy based on tuple length
  552. if len(concurrency) == 2:
  553. return ActorPoolStrategy(
  554. min_size=concurrency[0], max_size=concurrency[1]
  555. )
  556. else: # len(concurrency) == 3
  557. return ActorPoolStrategy(
  558. min_size=concurrency[0],
  559. max_size=concurrency[1],
  560. initial_size=concurrency[2],
  561. )
  562. elif isinstance(concurrency, int):
  563. if is_callable_class:
  564. return ActorPoolStrategy(size=concurrency)
  565. else:
  566. return TaskPoolStrategy(size=concurrency)
  567. else:
  568. raise ValueError(
  569. "``concurrency`` is expected to be set as an integer or a "
  570. f"tuple of integers, but got: {concurrency}."
  571. )
  572. else:
  573. if is_callable_class:
  574. return ActorPoolStrategy(min_size=1, max_size=None)
  575. else:
  576. return TaskPoolStrategy()
  577. def capfirst(s: str):
  578. """Capitalize the first letter of a string
  579. Args:
  580. s: String to capitalize
  581. Returns:
  582. Capitalized string
  583. """
  584. return s[0].upper() + s[1:]
  585. def capitalize(s: str):
  586. """Capitalize a string, removing '_' and keeping camelcase.
  587. Args:
  588. s: String to capitalize
  589. Returns:
  590. Capitalized string with no underscores.
  591. """
  592. return "".join(capfirst(x) for x in s.split("_"))
  593. def pandas_df_to_arrow_block(
  594. df: "pandas.DataFrame",
  595. ) -> Tuple["Block", "BlockMetadataWithSchema"]:
  596. from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadataWithSchema
  597. block = BlockAccessor.for_block(df).to_arrow()
  598. stats = BlockExecStats.builder()
  599. return block, BlockMetadataWithSchema.from_block(block, stats=stats.build())
  600. def ndarray_to_block(
  601. ndarray: np.ndarray, ctx: DataContext
  602. ) -> Tuple["Block", "BlockMetadataWithSchema"]:
  603. from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadataWithSchema
  604. DataContext._set_current(ctx)
  605. stats = BlockExecStats.builder()
  606. block = BlockAccessor.batch_to_block({"data": ndarray})
  607. return block, BlockMetadataWithSchema.from_block(block, stats=stats.build())
  608. def get_table_block_metadata_schema(
  609. table: Union["pyarrow.Table", "pandas.DataFrame"],
  610. ) -> "BlockMetadataWithSchema":
  611. from ray.data.block import BlockExecStats, BlockMetadataWithSchema
  612. stats = BlockExecStats.builder()
  613. return BlockMetadataWithSchema.from_block(table, stats=stats.build())
  614. def unify_block_metadata_schema(
  615. block_metadata_with_schemas: List["BlockMetadataWithSchema"],
  616. ) -> Optional["Schema"]:
  617. """For the input list of BlockMetadata, return a unified schema of the
  618. corresponding blocks. If the metadata have no valid schema, returns None.
  619. Args:
  620. block_metadata_with_schemas: List of BlockMetadata to unify
  621. Returns:
  622. A unified schema of the input list of schemas, or None if no valid schemas
  623. are provided.
  624. """
  625. # Some blocks could be empty, in which case we cannot get their schema.
  626. # TODO(ekl) validate schema is the same across different blocks.
  627. # First check if there are blocks with computed schemas, then unify
  628. # valid schemas from all such blocks.
  629. schemas_to_unify = []
  630. for m in block_metadata_with_schemas:
  631. if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
  632. schemas_to_unify.append(m.schema)
  633. return unify_schemas_with_validation(schemas_to_unify)
  634. def unify_schemas_with_validation(
  635. schemas_to_unify: Iterable["Schema"],
  636. ) -> Optional["Schema"]:
  637. if schemas_to_unify:
  638. from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
  639. # Check valid pyarrow installation before attempting schema unification
  640. try:
  641. import pyarrow as pa
  642. except ImportError:
  643. pa = None
  644. # If the result contains PyArrow schemas, unify them
  645. if pa is not None and all(isinstance(s, pa.Schema) for s in schemas_to_unify):
  646. return unify_schemas(schemas_to_unify, promote_types=True)
  647. # Otherwise, if the resulting schemas are simple types (e.g. int),
  648. # return the first schema.
  649. return schemas_to_unify[0]
  650. return None
  651. def unify_ref_bundles_schema(
  652. ref_bundles: List["RefBundle"],
  653. ) -> Optional["Schema"]:
  654. schemas_to_unify = []
  655. for bundle in ref_bundles:
  656. if bundle.schema is not None and (
  657. bundle.num_rows() is None or bundle.num_rows() > 0
  658. ):
  659. schemas_to_unify.append(bundle.schema)
  660. return unify_schemas_with_validation(schemas_to_unify)
def find_partition_index(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
    desired: Tuple[Union[int, float]],
    sort_key: "SortKey",
) -> int:
    """For the given block, find the index where the desired value should be
    added, to maintain sorted order.

    We do this by iterating over each column, starting with the primary sort key,
    and binary searching for the desired value in the column. Each binary search
    shortens the "range" of indices (represented by ``left`` and ``right``, which
    are indices of rows) where the desired value could be inserted.

    Args:
        table: The block to search in. It must already be sorted by
            ``sort_key`` for the binary searches to be meaningful.
        desired: A single tuple representing the boundary to partition at.
            ``len(desired)`` must be less than or equal to the number of columns
            being sorted. ``None`` entries are replaced with ``NULL_SENTINEL``
            before searching.
        sort_key: The sort key to use for sorting, providing the columns to be
            sorted and their directions.

    Returns:
        The index where the desired value should be inserted to maintain sorted
        order: the left end of the final matching range when the primary sort
        column is ascending, the right end when it is descending. If the range
        collapses early (``left == right``), that index is returned directly.
    """
    columns = sort_key.get_columns()
    descending = sort_key.get_descending()
    # [left, right) is the row range that matches every column processed so far.
    left, right = 0, len(table)
    for i in range(len(desired)):
        if left == right:
            return right
        col_name = columns[i]
        # Only search within the rows that still match all previous columns.
        col_vals = table[col_name].to_numpy()[left:right]
        desired_val = desired[i]
        # Handle null values - replace them with sentinel values
        # NOTE(review): only ``desired_val`` is substituted; nulls inside
        # ``col_vals`` are left as-is, so how they compare against
        # ``NULL_SENTINEL`` depends on its definition (not visible here) — confirm.
        if desired_val is None:
            desired_val = NULL_SENTINEL
        # searchsorted returns offsets relative to the slice; ``prevleft`` converts
        # them back to absolute row indices.
        prevleft = left
        if descending[i] is True:
            # ``np.searchsorted`` expects the array to be sorted in ascending
            # order, so we pass ``sorter``, which is an array of integer indices
            # that sort ``col_vals`` into ascending order. The returned index
            # is an index into the ascending order of ``col_vals``, so we need
            # to subtract it from ``len(col_vals)`` to get the index in the
            # original descending order of ``col_vals``.
            sorter = np.arange(len(col_vals) - 1, -1, -1)
            left = prevleft + (
                len(col_vals)
                - np.searchsorted(
                    col_vals,
                    desired_val,
                    side="right",
                    sorter=sorter,
                )
            )
            right = prevleft + (
                len(col_vals)
                - np.searchsorted(
                    col_vals,
                    desired_val,
                    side="left",
                    sorter=sorter,
                )
            )
        else:
            left = prevleft + np.searchsorted(col_vals, desired_val, side="left")
            right = prevleft + np.searchsorted(col_vals, desired_val, side="right")
    # The end of the range that is reported depends on the primary column's direction.
    return right if descending[0] is True else left
  726. def get_attribute_from_class_name(class_name: str) -> Any:
  727. """Get Python attribute from the provided class name.
  728. The caller needs to make sure the provided class name includes
  729. full module name, and can be imported successfully.
  730. """
  731. from importlib import import_module
  732. paths = class_name.split(".")
  733. if len(paths) < 2:
  734. raise ValueError(f"Cannot create object from {class_name}.")
  735. module_name = ".".join(paths[:-1])
  736. attribute_name = paths[-1]
  737. return getattr(import_module(module_name), attribute_name)
  738. T = TypeVar("T")
  739. U = TypeVar("U")
  740. class _InterruptibleQueue(Queue):
  741. """Extension of Python's `queue.Queue` providing ability to get interrupt its
  742. method callers in other threads"""
  743. INTERRUPTION_CHECK_FREQUENCY_SEC = 0.5
  744. def __init__(
  745. self, max_size: int, interrupted_event: Optional[threading.Event] = None
  746. ):
  747. super().__init__(maxsize=max_size)
  748. self._interrupted_event = interrupted_event or threading.Event()
  749. def get(self, block=True, timeout=None):
  750. if not block or timeout is not None:
  751. return super().get(block, timeout)
  752. # In case when the call is blocking and no timeout is specified (ie blocking
  753. # indefinitely) we apply the following protocol to make it interruptible:
  754. #
  755. # 1. `Queue.get` is invoked w/ 500ms timeout
  756. # 2. `Empty` exception is intercepted (will be raised upon timeout elapsing)
  757. # 3. If interrupted flag is set `InterruptedError` is raised
  758. # 4. Otherwise, protocol retried (until interrupted or queue
  759. # becoming non-empty)
  760. while True:
  761. if self._interrupted_event.is_set():
  762. raise InterruptedError()
  763. try:
  764. return super().get(
  765. block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
  766. )
  767. except Empty:
  768. pass
  769. def put(self, item, block=True, timeout=None):
  770. if not block or timeout is not None:
  771. super().put(item, block, timeout)
  772. return
  773. # In case when the call is blocking and no timeout is specified (ie blocking
  774. # indefinitely) we apply the following protocol to make it interruptible:
  775. #
  776. # 1. `Queue.pet` is invoked w/ 500ms timeout
  777. # 2. `Full` exception is intercepted (will be raised upon timeout elapsing)
  778. # 3. If interrupted flag is set `InterruptedError` is raised
  779. # 4. Otherwise, protocol retried (until interrupted or queue
  780. # becomes non-full)
  781. while True:
  782. if self._interrupted_event.is_set():
  783. raise InterruptedError()
  784. try:
  785. super().put(
  786. item, block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
  787. )
  788. return
  789. except Full:
  790. pass
def make_async_gen(
    base_iterator: Iterator[T],
    fn: Callable[[Iterator[T]], Iterator[U]],
    preserve_ordering: bool,
    num_workers: int = 1,
    buffer_size: int = 1,
) -> Generator[U, None, None]:
    """Returns a generator (iterator) mapping items from the
    provided iterator applying provided transformation in parallel (using a
    thread-pool).

    Because this function contains ``yield``, calling it starts no work.
    Argument validation (the ``num_workers`` check) and thread start-up happen
    on the first ``next()`` call.

    NOTE: There are some important constraints that need to be carefully
    understood before using this method

    1. If `preserve_ordering` is True

        a. This method would unroll input iterator eagerly (irrespective
        of the speed of resulting generator being consumed). This is necessary
        as we can not guarantee liveness of the algorithm AND preserving of the
        original ordering at the same time. (The per-worker input queues are
        created with ``maxsize=-1``, i.e. unbounded.)

        b. Resulting ordering of the output will "match" ordering of the input, ie
        that:

            iterator = [A1, A2, ... An]
            output iterator = [map(A1), map(A2), ..., map(An)]

    2. If `preserve_ordering` is False

        a. All elements go through a single shared input queue bounded at
        ``(buffer_size + 1) * num_workers`` elements, which limits how far the
        filling thread can read ahead of the transforming workers.

        b. Resulting ordering of the output is unspecified (and is
        non-deterministic)

    Args:
        base_iterator: Iterator yielding elements to map
        fn: Transformation applied by each worker to the iterator of elements
            it dequeues; it may yield any number of outputs per input.
        preserve_ordering: Whether ordering has to be preserved
        num_workers: The number of threads to use in the threadpool (defaults to 1)
        buffer_size: Capacity of each per-worker output queue (defaults to 1).
            When ``preserve_ordering`` is False it also sizes the shared input
            queue (see 2a above).

    Yields:
        Elements produced by ``fn`` (ordering as described above).

    Raises:
        ValueError: If ``num_workers < 1`` (raised on the first ``next()``).
        Exception: An exception raised by ``base_iterator`` or ``fn`` in a
            worker thread is forwarded through the output queues and re-raised
            by this generator.
    """
    gen_id = random.randint(0, 2**31 - 1)
    if num_workers < 1:
        raise ValueError("Size of threadpool must be at least 1.")
    # Event used to interrupt workers (blocked in queue get/put) when terminating
    interrupted_event = threading.Event()
    # To apply transformations to elements in parallel *and* preserve the ordering
    # following invariants are established:
    # - Every worker is handled by a standalone thread
    # - Every worker is assigned an input and an output queue
    #
    # And following protocol is implemented:
    # - Filling worker traverses input iterator round-robin'ing elements across
    # the input queues (in order!)
    # - Transforming workers traverse respective input queue in-order: de-queueing
    # element, applying transformation and enqueuing the result into the output
    # queue
    # - Generator (returned from this method) traverses output queues (in the same
    # order as input queues) dequeues 1 mapped element at a time from each output
    # queue and yields it
    #
    # However, in case when we're preserving the ordering we can not enforce the input
    # queue size as this could result in deadlocks since transformations could be
    # producing sequences of arbitrary length.
    #
    # Check `test_make_async_gen_varying_seq_length_stress_test` for more context on
    # this problem.
    if preserve_ordering:
        # maxsize <= 0 makes ``queue.Queue`` unbounded.
        input_queue_buf_size = -1
        num_input_queues = num_workers
    else:
        input_queue_buf_size = (buffer_size + 1) * num_workers
        num_input_queues = 1
    input_queues = [
        _InterruptibleQueue(input_queue_buf_size, interrupted_event)
        for _ in range(num_input_queues)
    ]
    output_queues = [
        _InterruptibleQueue(buffer_size, interrupted_event) for _ in range(num_workers)
    ]

    # Filling worker
    def _run_filling_worker():
        try:
            # First, round-robin elements from the iterator into
            # corresponding input queues (one by one)
            for idx, item in enumerate(base_iterator):
                input_queues[idx % num_input_queues].put(item)
            # NOTE: We have to Enqueue sentinel objects for every transforming
            # worker:
            # - In case of preserving order ``num_input_queues`` == ``num_workers``
            # so we will enqueue 1 sentinel per queue
            # - In case of NOT preserving order all ``num_workers`` sentinels
            # will be enqueued into a single queue
            for idx in range(num_workers):
                input_queues[idx % num_input_queues].put(SENTINEL)
        except InterruptedError:
            pass
        except Exception as e:
            logger.warning("Caught exception in filling worker!", exc_info=e)
            # In case of filling worker encountering an exception we have to propagate
            # it back to the (main) iterating thread. To achieve that we're traversing
            # output queues *backwards* relative to the order of iterator-thread such
            # that they are more likely to meet w/in a single iteration.
            # NOTE(review): if this ``put`` is itself interrupted, the resulting
            # InterruptedError escapes this handler (the sibling ``except`` above
            # does not apply) and ends the thread with an unhandled exception.
            for output_queue in reversed(output_queues):
                output_queue.put(e)

    # Transforming worker
    def _run_transforming_worker(input_queue, output_queue):
        try:
            # Create iterator draining the queue, until it receives sentinel
            #
            # NOTE: `queue.get` is blocking!
            input_queue_iter = iter(input_queue.get, SENTINEL)
            for result in fn(input_queue_iter):
                # Enqueue result of the transformation
                output_queue.put(result)
            # Enqueue sentinel (to signal that transformations are completed)
            output_queue.put(SENTINEL)
        except InterruptedError:
            pass
        except Exception as e:
            logger.warning("Caught exception in transforming worker!", exc_info=e)
            # NOTE: In this case we simply enqueue the exception rather than
            # interrupting
            output_queue.put(e)

    # Start workers threads
    filling_worker_thread = threading.Thread(
        target=_run_filling_worker,
        name=f"map_tp_filling_worker-{gen_id}",
        daemon=True,
    )
    filling_worker_thread.start()
    transforming_worker_threads = [
        threading.Thread(
            target=_run_transforming_worker,
            name=f"map_tp_transforming_worker-{gen_id}-{idx}",
            args=(input_queues[idx % num_input_queues], output_queues[idx]),
            daemon=True,
        )
        for idx in range(num_workers)
    ]
    for t in transforming_worker_threads:
        t.start()
    # Use main thread to yield output batches
    try:
        # Keep track of remaining non-empty output queues
        remaining_output_queues = output_queues
        while len(remaining_output_queues) > 0:
            # To provide deterministic ordering of the produced iterator we rely
            # on the following invariants:
            #
            # - Elements from the original iterator are round-robin'd into
            # input queues (in order)
            # - Individual workers drain their respective input queues populating
            # output queues with the results of applying transformation to the
            # original item (and hence preserving original ordering of the input
            # queue)
            # - To yield from the generator output queues are traversed in the same
            # order and one single element is dequeued (in a blocking way!) at a
            # time from every individual output queue
            #
            empty_queues = []
            # At every iteration only remaining non-empty queues
            # are traversed (to prevent blocking on exhausted queue)
            for output_queue in remaining_output_queues:
                # NOTE: This is blocking!
                item = output_queue.get()
                # Exceptions forwarded by a worker are re-raised to the consumer.
                if isinstance(item, Exception):
                    raise item
                if item is SENTINEL:
                    empty_queues.append(output_queue)
                else:
                    yield item
            if empty_queues:
                remaining_output_queues = [
                    q for q in remaining_output_queues if q not in empty_queues
                ]
    finally:
        # Set flag to interrupt workers (to make sure no dangling
        # threads holding the objects are left behind)
        #
        # NOTE: Interrupted event is set to interrupt the running threads
        # that might be blocked otherwise waiting on inputs from respective
        # queues. However, even though we're interrupting the threads we can't
        # guarantee that threads will be interrupted in time (as this is
        # dependent on Python's GC finalizer to close the generator by raising
        # `GeneratorExit`) and hence we can't join on either filling or
        # transforming workers.
        interrupted_event.set()
  977. class RetryingContextManager:
  978. def __init__(
  979. self,
  980. f: pyarrow.NativeFile,
  981. context: DataContext,
  982. max_attempts: int = 10,
  983. max_backoff_s: int = 32,
  984. ):
  985. self._f = f
  986. self._data_context = context
  987. self._max_attempts = max_attempts
  988. self._max_backoff_s = max_backoff_s
  989. def __repr__(self):
  990. return f"<{self.__class__.__name__} fs={self.handler.unwrap()}>"
  991. def _retry_operation(self, operation: Callable, description: str):
  992. """Execute an operation with retries."""
  993. return call_with_retry(
  994. operation,
  995. description=description,
  996. match=self._data_context.retried_io_errors,
  997. max_attempts=self._max_attempts,
  998. max_backoff_s=self._max_backoff_s,
  999. )
  1000. def __enter__(self):
  1001. return self._retry_operation(self._f.__enter__, "enter file context")
  1002. def __exit__(self, exc_type, exc_value, traceback):
  1003. self._retry_operation(
  1004. lambda: self._f.__exit__(exc_type, exc_value, traceback),
  1005. "exit file context",
  1006. )
  1007. class RetryingPyFileSystem(pyarrow.fs.PyFileSystem):
  1008. def __init__(self, handler: "RetryingPyFileSystemHandler"):
  1009. if not isinstance(handler, RetryingPyFileSystemHandler):
  1010. assert ValueError("handler must be a RetryingPyFileSystemHandler")
  1011. super().__init__(handler)
  1012. @property
  1013. def retryable_errors(self) -> List[str]:
  1014. return self.handler._retryable_errors
  1015. def unwrap(self):
  1016. return self.handler.unwrap()
  1017. @classmethod
  1018. def wrap(
  1019. cls,
  1020. fs: "pyarrow.fs.FileSystem",
  1021. retryable_errors: List[str],
  1022. max_attempts: int = 10,
  1023. max_backoff_s: int = 32,
  1024. ):
  1025. if isinstance(fs, RetryingPyFileSystem):
  1026. return fs
  1027. handler = RetryingPyFileSystemHandler(
  1028. fs, retryable_errors, max_attempts, max_backoff_s
  1029. )
  1030. return cls(handler)
  1031. def __reduce__(self):
  1032. # Serialization of this class breaks for some reason without this
  1033. return (self.__class__, (self.handler,))
  1034. @classmethod
  1035. def __setstate__(cls, state):
  1036. # Serialization of this class breaks for some reason without this
  1037. return cls(*state)
  1038. class RetryingPyFileSystemHandler(pyarrow.fs.FileSystemHandler):
  1039. """Wrapper for filesystem objects that adds retry functionality for file operations.
  1040. This class wraps any filesystem object and adds automatic retries for common
  1041. file operations that may fail transiently.
  1042. """
  1043. def __init__(
  1044. self,
  1045. fs: "pyarrow.fs.FileSystem",
  1046. retryable_errors: List[str] = tuple(),
  1047. max_attempts: int = 10,
  1048. max_backoff_s: int = 32,
  1049. ):
  1050. """Initialize the retrying filesystem wrapper.
  1051. Args:
  1052. fs: The underlying filesystem to wrap
  1053. context: DataContext for retry settings
  1054. max_attempts: Maximum number of retry attempts
  1055. max_backoff_s: Maximum backoff time in seconds
  1056. """
  1057. assert not isinstance(
  1058. fs, RetryingPyFileSystem
  1059. ), "Cannot wrap a RetryingPyFileSystem"
  1060. self._fs = fs
  1061. self._retryable_errors = retryable_errors
  1062. self._max_attempts = max_attempts
  1063. self._max_backoff_s = max_backoff_s
  1064. def _retry_operation(self, operation: Callable, description: str):
  1065. """Execute an operation with retries."""
  1066. return call_with_retry(
  1067. operation,
  1068. description=description,
  1069. match=self._retryable_errors,
  1070. max_attempts=self._max_attempts,
  1071. max_backoff_s=self._max_backoff_s,
  1072. )
  1073. def unwrap(self):
  1074. return self._fs
  1075. def copy_file(self, src: str, dest: str):
  1076. """Copy a file."""
  1077. return self._retry_operation(
  1078. lambda: self._fs.copy_file(src, dest), f"copy file from {src} to {dest}"
  1079. )
  1080. def create_dir(self, path: str, recursive: bool):
  1081. """Create a directory and subdirectories."""
  1082. return self._retry_operation(
  1083. lambda: self._fs.create_dir(path, recursive=recursive),
  1084. f"create directory {path}",
  1085. )
  1086. def delete_dir(self, path: str):
  1087. """Delete a directory and its contents, recursively."""
  1088. return self._retry_operation(
  1089. lambda: self._fs.delete_dir(path), f"delete directory {path}"
  1090. )
  1091. def delete_dir_contents(self, path: str, missing_dir_ok: bool = False):
  1092. """Delete a directory's contents, recursively."""
  1093. return self._retry_operation(
  1094. lambda: self._fs.delete_dir_contents(path, missing_dir_ok=missing_dir_ok),
  1095. f"delete directory contents {path}",
  1096. )
  1097. def delete_file(self, path: str):
  1098. """Delete a file."""
  1099. return self._retry_operation(
  1100. lambda: self._fs.delete_file(path), f"delete file {path}"
  1101. )
  1102. def delete_root_dir_contents(self):
  1103. return self._retry_operation(
  1104. lambda: self._fs.delete_dir_contents("/", accept_root_dir=True),
  1105. "delete root dir contents",
  1106. )
  1107. def equals(self, other: "pyarrow.fs.FileSystem") -> bool:
  1108. """Test if this filesystem equals another."""
  1109. return self._fs.equals(other)
  1110. def get_file_info(self, paths: List[str]):
  1111. """Get info for the given files."""
  1112. return self._retry_operation(
  1113. lambda: self._fs.get_file_info(paths),
  1114. f"get file info for {paths}",
  1115. )
  1116. def get_file_info_selector(self, selector):
  1117. return self._retry_operation(
  1118. lambda: self._fs.get_file_info(selector),
  1119. f"get file info for {selector}",
  1120. )
  1121. def get_type_name(self):
  1122. return "RetryingPyFileSystem"
  1123. def move(self, src: str, dest: str):
  1124. """Move / rename a file or directory."""
  1125. return self._retry_operation(
  1126. lambda: self._fs.move(src, dest), f"move from {src} to {dest}"
  1127. )
  1128. def normalize_path(self, path: str) -> str:
  1129. """Normalize filesystem path."""
  1130. return self._retry_operation(
  1131. lambda: self._fs.normalize_path(path), f"normalize path {path}"
  1132. )
  1133. def open_append_stream(
  1134. self,
  1135. path: str,
  1136. metadata=None,
  1137. ) -> "pyarrow.NativeFile":
  1138. """Open an output stream for appending.
  1139. Compression is disabled in this method because it is handled in the
  1140. PyFileSystem abstract class.
  1141. """
  1142. return self._retry_operation(
  1143. lambda: self._fs.open_append_stream(
  1144. path,
  1145. compression=None,
  1146. metadata=metadata,
  1147. ),
  1148. f"open append stream for {path}",
  1149. )
  1150. def open_input_stream(
  1151. self,
  1152. path: str,
  1153. ) -> "pyarrow.NativeFile":
  1154. """Open an input stream for sequential reading.
  1155. Compression is disabled in this method because it is handled in the
  1156. PyFileSystem abstract class.
  1157. """
  1158. return self._retry_operation(
  1159. lambda: self._fs.open_input_stream(path, compression=None),
  1160. f"open input stream for {path}",
  1161. )
  1162. def open_output_stream(
  1163. self,
  1164. path: str,
  1165. metadata=None,
  1166. ) -> "pyarrow.NativeFile":
  1167. """Open an output stream for sequential writing."
  1168. Compression is disabled in this method because it is handled in the
  1169. PyFileSystem abstract class.
  1170. """
  1171. return self._retry_operation(
  1172. lambda: self._fs.open_output_stream(
  1173. path,
  1174. compression=None,
  1175. metadata=metadata,
  1176. ),
  1177. f"open output stream for {path}",
  1178. )
  1179. def open_input_file(self, path: str) -> "pyarrow.NativeFile":
  1180. """Open an input file for random access reading."""
  1181. return self._retry_operation(
  1182. lambda: self._fs.open_input_file(path), f"open input file {path}"
  1183. )
  1184. def iterate_with_retry(
  1185. iterable_factory: Callable[[], Iterable],
  1186. description: str,
  1187. *,
  1188. match: Optional[List[str]] = None,
  1189. max_attempts: int = 10,
  1190. max_backoff_s: int = 32,
  1191. ) -> Any:
  1192. """Iterate through an iterable with retries.
  1193. If the iterable raises an exception, this function recreates and re-iterates
  1194. through the iterable, while skipping the items that have already been yielded.
  1195. Args:
  1196. iterable_factory: A no-argument function that creates the iterable.
  1197. match: A list of strings to match in the exception message. If ``None``, any
  1198. error is retried.
  1199. description: An imperitive description of the function being retried. For
  1200. example, "open the file".
  1201. max_attempts: The maximum number of attempts to retry.
  1202. max_backoff_s: The maximum number of seconds to backoff.
  1203. """
  1204. assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."
  1205. num_items_yielded = 0
  1206. for attempt in range(max_attempts):
  1207. try:
  1208. iterable = iterable_factory()
  1209. for item_index, item in enumerate(iterable):
  1210. if item_index < num_items_yielded:
  1211. # Skip items that have already been yielded.
  1212. continue
  1213. num_items_yielded += 1
  1214. yield item
  1215. return
  1216. except Exception as e:
  1217. is_retryable = match is None or any(pattern in str(e) for pattern in match)
  1218. if is_retryable and attempt + 1 < max_attempts:
  1219. # Retry with binary expoential backoff with random jitter.
  1220. backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random()
  1221. logger.debug(
  1222. f"Retrying {attempt+1} attempts to {description} "
  1223. f"after {backoff} seconds."
  1224. )
  1225. time.sleep(backoff)
  1226. else:
  1227. raise e from None
  1228. def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
  1229. if num_bytes >= 1e9:
  1230. num_bytes_str = f"{round(num_bytes / 1e9)}GB"
  1231. elif num_bytes >= 1e6:
  1232. num_bytes_str = f"{round(num_bytes / 1e6)}MB"
  1233. else:
  1234. num_bytes_str = f"{round(num_bytes / 1e3)}KB"
  1235. return num_bytes_str
  1236. def _validate_rows_per_file_args(
  1237. *,
  1238. num_rows_per_file: Optional[int] = None,
  1239. min_rows_per_file: Optional[int] = None,
  1240. max_rows_per_file: Optional[int] = None,
  1241. ) -> Tuple[Optional[int], Optional[int]]:
  1242. """Helper method to validate and handle rows per file arguments.
  1243. Args:
  1244. num_rows_per_file: Deprecated parameter for number of rows per file
  1245. min_rows_per_file: New parameter for minimum rows per file
  1246. max_rows_per_file: New parameter for maximum rows per file
  1247. Returns:
  1248. A tuple of (effective_min_rows_per_file, effective_max_rows_per_file)
  1249. """
  1250. if num_rows_per_file is not None:
  1251. import warnings
  1252. warnings.warn(
  1253. "`num_rows_per_file` is deprecated and will be removed in a future release. "
  1254. "Use `min_rows_per_file` instead.",
  1255. DeprecationWarning,
  1256. stacklevel=3,
  1257. )
  1258. if min_rows_per_file is not None:
  1259. raise ValueError(
  1260. "Cannot specify both `num_rows_per_file` and `min_rows_per_file`. "
  1261. "Use `min_rows_per_file` as `num_rows_per_file` is deprecated."
  1262. )
  1263. min_rows_per_file = num_rows_per_file
  1264. # Validate max_rows_per_file
  1265. if max_rows_per_file is not None and max_rows_per_file <= 0:
  1266. raise ValueError("max_rows_per_file must be a positive integer")
  1267. # Validate min_rows_per_file
  1268. if min_rows_per_file is not None and min_rows_per_file <= 0:
  1269. raise ValueError("min_rows_per_file must be a positive integer")
  1270. # Validate that max >= min if both are specified
  1271. if (
  1272. min_rows_per_file is not None
  1273. and max_rows_per_file is not None
  1274. and min_rows_per_file > max_rows_per_file
  1275. ):
  1276. raise ValueError(
  1277. f"min_rows_per_file ({min_rows_per_file}) cannot be greater than "
  1278. f"max_rows_per_file ({max_rows_per_file})"
  1279. )
  1280. return min_rows_per_file, max_rows_per_file
  1281. def is_nan(value) -> bool:
  1282. """Returns true if provide value is ``np.nan``"""
  1283. try:
  1284. return isinstance(value, float) and np.isnan(value)
  1285. except TypeError:
  1286. return False
  1287. def is_null(value: Any) -> bool:
  1288. """This generalization of ``is_nan`` util qualifying both None and np.nan
  1289. as null values"""
  1290. return value is None or is_nan(value)
  1291. def keys_equal(keys1, keys2):
  1292. if len(keys1) != len(keys2):
  1293. return False
  1294. for k1, k2 in zip(keys1, keys2):
  1295. if not ((is_nan(k1) and is_nan(k2)) or k1 == k2):
  1296. return False
  1297. return True
  1298. def get_total_obj_store_mem_on_node() -> int:
  1299. """Return the total object store memory on the current node.
  1300. This function incurs an RPC. Use it cautiously.
  1301. """
  1302. node_id = ray.get_runtime_context().get_node_id()
  1303. total_resources_per_node = ray._private.state.total_resources_per_node()
  1304. assert (
  1305. node_id in total_resources_per_node
  1306. ), f"Expected node '{node_id}' to be in resources: {total_resources_per_node}"
  1307. return total_resources_per_node[node_id]["object_store_memory"]
  1308. class MemoryProfiler:
  1309. """A context manager that polls the USS of the current process.
  1310. This class approximates the max USS by polling memory and subtracting the amount
  1311. of shared memory from the resident set size (RSS). It's not a
  1312. perfect estimate (it can underestimate, e.g., if you use Torch tensors), but
  1313. estimating the USS is much cheaper than computing the actual USS.
  1314. .. warning::
  1315. This class only works with Linux. If you use it on another platform,
  1316. `estimate_max_uss` always returns ``None``.
  1317. Example:
  1318. .. testcode::
  1319. with MemoryProfiler(poll_interval_s=1.0) as profiler:
  1320. for i in range(10):
  1321. ... # Your code here
  1322. print(f"Max USS: {profiler.estimate_max_uss()}")
  1323. profiler.reset()
  1324. """
  1325. def __init__(self, poll_interval_s: Optional[float]):
  1326. """
  1327. Args:
  1328. poll_interval_s: The interval to poll the USS of the process. If `None`,
  1329. this class won't poll the USS.
  1330. """
  1331. self._poll_interval_s = poll_interval_s
  1332. self._process = psutil.Process(os.getpid())
  1333. self._max_uss = None
  1334. self._max_uss_lock = threading.Lock()
  1335. self._uss_poll_thread = None
  1336. self._stop_uss_poll_event = None
  1337. def __repr__(self):
  1338. return f"MemoryProfiler(poll_interval_s={self._poll_interval_s})"
  1339. def __enter__(self):
  1340. if self._can_estimate_uss() and self._poll_interval_s is not None:
  1341. (
  1342. self._uss_poll_thread,
  1343. self._stop_uss_poll_event,
  1344. ) = self._start_uss_poll_thread()
  1345. return self
  1346. def __exit__(self, exc_type, exc_val, exc_tb):
  1347. if self._uss_poll_thread is not None:
  1348. self._stop_uss_poll_thread()
  1349. def estimate_max_uss(self) -> Optional[int]:
  1350. """Get an estimate of the max USS of the current process.
  1351. Returns:
  1352. An estimate of the max USS of the process in bytes, or ``None`` if an
  1353. estimate isn't available.
  1354. """
  1355. if not self._can_estimate_uss():
  1356. assert self._max_uss is None
  1357. return None
  1358. with self._max_uss_lock:
  1359. if self._max_uss is None:
  1360. self._max_uss = self._estimate_uss()
  1361. else:
  1362. self._max_uss = max(self._max_uss, self._estimate_uss())
  1363. assert self._max_uss is not None
  1364. return self._max_uss
  1365. def reset(self):
  1366. with self._max_uss_lock:
  1367. self._max_uss = None
  1368. def _start_uss_poll_thread(self) -> Tuple[threading.Thread, threading.Event]:
  1369. assert self._poll_interval_s is not None
  1370. assert self._can_estimate_uss()
  1371. stop_event = threading.Event()
  1372. def poll_uss():
  1373. while not stop_event.is_set():
  1374. with self._max_uss_lock:
  1375. if self._max_uss is None:
  1376. self._max_uss = self._estimate_uss()
  1377. else:
  1378. self._max_uss = max(self._max_uss, self._estimate_uss())
  1379. stop_event.wait(self._poll_interval_s)
  1380. thread = threading.Thread(target=poll_uss, daemon=True)
  1381. thread.start()
  1382. return thread, stop_event
  1383. def _stop_uss_poll_thread(self):
  1384. if self._stop_uss_poll_event is not None:
  1385. self._stop_uss_poll_event.set()
  1386. self._uss_poll_thread.join()
  1387. def _estimate_uss(self) -> int:
  1388. assert self._can_estimate_uss()
  1389. memory_info = self._process.memory_info()
  1390. # Estimate the USS (the amount of memory that'd be free if we killed the
  1391. # process right now) as the difference between the RSS (total physical memory)
  1392. # and amount of shared physical memory.
  1393. return memory_info.rss - memory_info.shared
  1394. @staticmethod
  1395. @functools.cache
  1396. def _can_estimate_uss() -> bool:
  1397. # MacOS and Windows don't have the 'shared' attribute of `memory_info()`.
  1398. return platform.system() == "Linux"
  1399. def unzip(data: List[Tuple[Any, ...]]) -> Tuple[List[Any], ...]:
  1400. """Unzips a list of tuples into a tuple of lists
  1401. Args:
  1402. data: A list of tuples to unzip.
  1403. Returns:
  1404. A tuple of lists, where each list corresponds to one element of the tuples in
  1405. the input list.
  1406. """
  1407. return tuple(map(list, zip(*data)))
  1408. def _sort_df(df: pd.DataFrame) -> pd.DataFrame:
  1409. """Sort DataFrame by columns and rows, and also handle unhashable types."""
  1410. df = df.copy()
  1411. def to_sortable(x):
  1412. if isinstance(x, (list, np.ndarray)):
  1413. return tuple(to_sortable(i) for i in x)
  1414. if isinstance(x, dict):
  1415. return tuple(sorted((k, to_sortable(v)) for k, v in x.items()))
  1416. return x
  1417. sort_cols = []
  1418. temp_cols = []
  1419. # Sort by all columns to ensure deterministic order.
  1420. columns = sorted(df.columns)
  1421. for col in columns:
  1422. if df[col].dtype == "object":
  1423. # Create a temporary column for sorting to handle unhashable types.
  1424. # Use UUID to avoid collisions with existing column names.
  1425. temp_col = f"__sort_proxy_{uuid.uuid4().hex}_{col}__"
  1426. df[temp_col] = df[col].map(to_sortable)
  1427. sort_cols.append(temp_col)
  1428. temp_cols.append(temp_col)
  1429. else:
  1430. sort_cols.append(col)
  1431. sorted_df = df.sort_values(sort_cols)
  1432. if temp_cols:
  1433. sorted_df = sorted_df.drop(columns=temp_cols)
  1434. return sorted_df
  1435. def rows_same(actual: pd.DataFrame, expected: pd.DataFrame) -> bool:
  1436. """Check if two DataFrames have the same rows.
  1437. Unlike the built-in pandas equals method, this function ignores indices and the
  1438. order of rows. This is useful for testing Ray Data because its interface doesn't
  1439. usually guarantee the order of rows.
  1440. """
  1441. if len(actual) != len(expected):
  1442. return False
  1443. if len(actual) == 0:
  1444. return True
  1445. pd.testing.assert_frame_equal(
  1446. _sort_df(actual).reset_index(drop=True),
  1447. _sort_df(expected).reset_index(drop=True),
  1448. check_dtype=False,
  1449. )
  1450. return True
  1451. def merge_resources_to_ray_remote_args(
  1452. num_cpus: Optional[int],
  1453. num_gpus: Optional[int],
  1454. memory: Optional[int],
  1455. ray_remote_args: Dict[str, Any],
  1456. ) -> Dict[str, Any]:
  1457. """Convert the given resources to Ray remote args.
  1458. Args:
  1459. num_cpus: The number of CPUs to be added to the Ray remote args.
  1460. num_gpus: The number of GPUs to be added to the Ray remote args.
  1461. memory: The memory to be added to the Ray remote args.
  1462. ray_remote_args: The Ray remote args to be merged.
  1463. Returns:
  1464. The converted arguments.
  1465. """
  1466. ray_remote_args = ray_remote_args.copy()
  1467. if num_cpus is not None:
  1468. ray_remote_args["num_cpus"] = num_cpus
  1469. if num_gpus is not None:
  1470. ray_remote_args["num_gpus"] = num_gpus
  1471. if memory is not None:
  1472. ray_remote_args["memory"] = memory
  1473. return ray_remote_args
  1474. @DeveloperAPI
  1475. def infer_compression(path: str) -> Optional[str]:
  1476. import pyarrow as pa
  1477. compression = None
  1478. try:
  1479. # Try to detect compression codec from path.
  1480. compression = pa.Codec.detect(path).name
  1481. except (ValueError, TypeError):
  1482. # Arrow's compression inference on the file path doesn't work for Snappy, so we double-check ourselves.
  1483. import pathlib
  1484. suffix = pathlib.Path(path).suffix
  1485. if suffix and suffix[1:] == "snappy":
  1486. compression = "snappy"
  1487. return compression