| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794 |
- import functools
- import importlib
- import logging
- import os
- import pathlib
- import platform
- import random
- import sys
- import threading
- import time
- import urllib.parse
- import uuid
- from queue import Empty, Full, Queue
- from types import ModuleType
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- Generator,
- Iterable,
- Iterator,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
- overload,
- )
- import numpy as np
- import pandas as pd
- # NOTE: pyarrow.fs module needs to be explicitly imported!
- import pyarrow
- import pyarrow.fs
- import ray
- from ray._common.retry import call_with_retry
- from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext
- from ray.util.annotations import DeveloperAPI
- import psutil
# TypeVar for preserving function/class signatures through decorators: a
# decorator typed as ``Callable[[F], F]`` hands back the same callable type it
# was given.
F = TypeVar("F", bound=Callable[..., Any])
- if TYPE_CHECKING:
- import pandas
- from ray.data._internal.compute import ComputeStrategy
- from ray.data._internal.execution.interfaces import RefBundle
- from ray.data._internal.planner.exchange.sort_task_spec import SortKey
- from ray.data.block import (
- Block,
- BlockMetadataWithSchema,
- Schema,
- UserDefinedFunction,
- )
- from ray.data.datasource import Datasource, Reader
- from ray.util.placement_group import PlacementGroup
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Binary size units, expressed in bytes.
KiB = 1024  # bytes
MiB = 1024 * KiB
GiB = 1024 * MiB

# Unique placeholder object; it can only be matched by identity, so it is
# distinguishable from every other value (including ``None``).
SENTINEL = object()

# Ray-specific URI schemes recognized by ``_resolve_custom_scheme`` and
# ``_is_local_scheme``.
_LOCAL_SCHEME = "local"
_EXAMPLE_SCHEME = "example"

# State of a lazily imported module: ``None`` until an import has been
# attempted, ``False`` once an attempt has failed, otherwise the module object.
LazyModule = Union[None, bool, ModuleType]

# Cache backing ``_lazy_import_pyarrow_dataset``.
_pyarrow_dataset: LazyModule = None
class _OrderedNullSentinel:
    """Stand-in for null values that sorts after every non-null value.

    NOTE: The comparison semantics deliberately mirror those of ``np.nan`` so
        that ``None``s and ``np.nan``s are handled consistently.
    """

    def __hash__(self):
        # Identity-based hash: every instance is its own key.
        return id(self)

    def __eq__(self, other):
        # The sentinel never compares equal to anything, itself included.
        return False

    def __lt__(self, other):
        # "Less than" holds only against other nulls:
        #   _OrderedNullSentinel < _OrderedNullSentinel
        #   _OrderedNullSentinel < None
        #   _OrderedNullSentinel < np.nan
        # and never against a non-null value.
        if isinstance(other, _OrderedNullSentinel):
            return True
        return is_null(other)

    def __le__(self, other):
        # NOTE: ``self == other`` is always False, so "less than or equal"
        #       collapses to "less than".
        return self.__lt__(other)

    def __ge__(self, other):
        return not self.__lt__(other)

    def __gt__(self, other):
        return not self.__le__(other)


NULL_SENTINEL = _OrderedNullSentinel()
def _lazy_import_pyarrow_dataset() -> LazyModule:
    """Return ``pyarrow.dataset``, importing it on first use.

    The outcome is cached in the module-level ``_pyarrow_dataset``: the module
    object on success, or ``False`` once an import attempt has raised
    ``ModuleNotFoundError``.

    Returns:
        The imported ``pyarrow.dataset`` module, or ``False`` if it couldn't
        be imported.
    """
    global _pyarrow_dataset
    if _pyarrow_dataset is not None:
        # Already resolved on an earlier call (module object or ``False``).
        return _pyarrow_dataset
    try:
        from pyarrow import dataset as _pyarrow_dataset
    except ModuleNotFoundError:
        # Remember the failure so that later calls to
        # _lazy_import_pyarrow_dataset() don't retry the import.
        _pyarrow_dataset = False
    return _pyarrow_dataset
def _check_pyarrow_version():
    """Delegate to ``ray.data._internal.utils.arrow_utils._check_pyarrow_version``."""
    arrow_utils = ray.data._internal.utils.arrow_utils
    arrow_utils._check_pyarrow_version()
def _autodetect_parallelism(
    parallelism: int,
    target_max_block_size: Optional[int],
    ctx: DataContext,
    datasource_or_legacy_reader: Optional[Union["Datasource", "Reader"]] = None,
    mem_size: Optional[int] = None,
    placement_group: Optional["PlacementGroup"] = None,
    avail_cpus: Optional[int] = None,
) -> Tuple[int, str, Optional[int]]:
    """Returns the parallelism to use, why it was chosen, and the estimated data size.

    A non-negative ``parallelism`` is returned unchanged with an empty reason.
    When ``parallelism`` is -1, it is detected using the following heuristics,
    applied in order:

     1) We start with the default value of 200. This can be overridden by
        setting the `read_op_min_num_blocks` attribute of
        :class:`~ray.data.context.DataContext`.
     2) Min block size. If the parallelism would make blocks smaller than this
        threshold, the parallelism is reduced to avoid the overhead of tiny blocks.
     3) Max block size. If the parallelism would make blocks larger than this
        threshold, the parallelism is increased to avoid OOMs during processing.
     4) Available CPUs. If the parallelism cannot make use of all the available
        CPUs in the cluster, the parallelism is increased until it can.

    NOTE: During auto-detection, if the deprecated ``ctx.min_parallelism`` was
        customized while ``ctx.read_op_min_num_blocks`` still has its default
        value, ``ctx.read_op_min_num_blocks`` is overwritten with it.

    Args:
        parallelism: The user-requested parallelism, or -1 for auto-detection.
        target_max_block_size: The target max block size to
            produce. We pass this separately from the
            DatasetContext because it may be set per-op instead of
            per-Dataset.
        ctx: The current Dataset context to use for configs.
        datasource_or_legacy_reader: The datasource or legacy reader, to be used for
            data size estimation.
        mem_size: If passed, then used to compute the parallelism according to
            target_max_block_size.
        placement_group: The placement group that this Dataset
            will execute inside, if any.
        avail_cpus: Override avail cpus detection (for testing only).

    Returns:
        Tuple of the parallelism to use (detected only if -1 was specified), the
        reason for the detected parallelism (empty unless -1 was specified), and
        the estimated in-memory size of the dataset (``None`` if it was neither
        passed in nor estimated from the datasource).

    Raises:
        ValueError: If ``parallelism`` is negative but not -1.
    """
    # Bounds used when no size estimate is available: no lower bound beyond a
    # single block and effectively no upper bound.
    min_safe_parallelism = 1
    max_reasonable_parallelism = sys.maxsize
    if mem_size is None and datasource_or_legacy_reader:
        mem_size = datasource_or_legacy_reader.estimate_inmemory_data_size()
    if (
        mem_size is not None
        and not np.isnan(mem_size)
        and target_max_block_size is not None
    ):
        # Fewest blocks that keep each block under the max block size, and most
        # blocks that keep each block over the min block size.
        min_safe_parallelism = max(1, int(mem_size / target_max_block_size))
        max_reasonable_parallelism = max(1, int(mem_size / ctx.target_min_block_size))

    reason = ""
    if parallelism < 0:
        if parallelism != -1:
            raise ValueError("`parallelism` must either be -1 or a positive integer.")

        # Honor the deprecated ``min_parallelism`` setting only if the user
        # customized it and left its replacement at the default.
        if (
            ctx.min_parallelism is not None
            and ctx.min_parallelism != DEFAULT_READ_OP_MIN_NUM_BLOCKS
            and ctx.read_op_min_num_blocks == DEFAULT_READ_OP_MIN_NUM_BLOCKS
        ):
            logger.warning(
                "``DataContext.min_parallelism`` is deprecated in Ray 2.10. "
                "Please specify ``DataContext.read_op_min_num_blocks`` instead."
            )
            ctx.read_op_min_num_blocks = ctx.min_parallelism

        # Start with 2x the number of cores as a baseline, with a min floor.
        if placement_group is None:
            placement_group = ray.util.get_current_placement_group()
        avail_cpus = avail_cpus or _estimate_avail_cpus(placement_group)
        parallelism = max(
            min(ctx.read_op_min_num_blocks, max_reasonable_parallelism),
            min_safe_parallelism,
            avail_cpus * 2,
        )

        # Work out which of the candidates above won, for the reason string.
        # The checks are ordered, so ties go to the earliest matching candidate.
        if parallelism == ctx.read_op_min_num_blocks:
            reason = (
                "DataContext.get_current().read_op_min_num_blocks="
                f"{ctx.read_op_min_num_blocks}"
            )
        elif parallelism == max_reasonable_parallelism:
            reason = (
                "output blocks of size at least "
                "DataContext.get_current().target_min_block_size="
                f"{ctx.target_min_block_size / MiB} MiB"
            )
        elif parallelism == min_safe_parallelism:
            # Handle ``None`` (unlimited) gracefully in the log message.
            if ctx.target_max_block_size is None:
                display_val = "unlimited"
            else:
                display_val = f"{ctx.target_max_block_size / MiB} MiB"
            reason = (
                "output blocks of size at most "
                "DataContext.get_current().target_max_block_size="
                f"{display_val}"
            )
        else:
            reason = (
                "parallelism at least twice the available number "
                f"of CPUs ({avail_cpus})"
            )

        logger.debug(
            f"Autodetected parallelism={parallelism} based on "
            f"estimated_available_cpus={avail_cpus} and "
            f"estimated_data_size={mem_size}."
        )

    return parallelism, reason, mem_size
def _estimate_avail_cpus(cur_pg: Optional["PlacementGroup"]) -> int:
    """Estimates the available CPU parallelism for this Dataset in the cluster.

    If we aren't in a placement group, this is trivially the number of CPUs in the
    cluster. Otherwise, we try to calculate how large the placement group is relative
    to the size of the cluster.

    Args:
        cur_pg: The current placement group, if any.

    Returns:
        The number of CPUs in the cluster, or, inside a placement group, an
        estimate derived from its bundles that never exceeds that number.
    """
    # Query the cluster once so that the CPU and GPU totals come from the same
    # snapshot (and to avoid a second, redundant lookup).
    cluster_resources = ray.cluster_resources()
    cluster_cpus = int(cluster_resources.get("CPU", 1))
    cluster_gpus = int(cluster_resources.get("GPU", 0))

    if not cur_pg:
        return cluster_cpus

    # If we're in a placement group, we shouldn't assume the entire cluster's
    # resources are available for us to use. Estimate an upper bound on what's
    # reasonable to assume is available for datasets to use.
    cpu_denominator = max(1, cluster_cpus)
    gpu_denominator = max(1, cluster_gpus)
    pg_cpus = 0
    for bundle in cur_pg.bundle_specs:
        # Calculate the proportion of the cluster this placement group "takes up".
        # Then scale our cluster_cpus proportionally to avoid over-parallelizing
        # if there are many parallel Tune trials using the cluster.
        cpu_fraction = bundle.get("CPU", 0) / cpu_denominator
        gpu_fraction = bundle.get("GPU", 0) / gpu_denominator
        max_fraction = max(cpu_fraction, gpu_fraction)
        # Over-parallelize by up to a factor of 2, but no more than that. It's
        # preferable to over-estimate than under-estimate.
        pg_cpus += 2 * int(max_fraction * cluster_cpus)

    return min(cluster_cpus, pg_cpus)
def _estimate_available_parallelism() -> int:
    """Estimates the CPU parallelism available to this Dataset in the cluster.

    The current placement group (if there is one) is looked up and taken into
    account by ``_estimate_avail_cpus``.
    """
    return _estimate_avail_cpus(ray.util.get_current_placement_group())
def _warn_on_high_parallelism(requested_parallelism, num_read_tasks):
    """Warn when a read is split into far more tasks than there are CPU slots.

    The warning is logged only if a parallelism was explicitly requested and
    the number of read tasks is both at least 5000 and more than 4x the
    currently available CPU slots.

    Args:
        requested_parallelism: The parallelism requested by the user; a falsy
            value suppresses the warning.
        num_read_tasks: The number of read tasks that will be launched.
    """
    available_cpu_slots = ray.available_resources().get("CPU", 1)

    if not requested_parallelism:
        return
    if not (num_read_tasks > available_cpu_slots * 4 and num_read_tasks >= 5000):
        return

    logger.warning(
        f"{WARN_PREFIX} The requested parallelism of {requested_parallelism} "
        "is more than 4x the number of available CPU slots in the cluster of "
        f"{available_cpu_slots}. This can "
        "lead to slowdowns during the data reading phase due to excessive "
        "task creation. Reduce the parallelism to match with the available "
        "CPU slots in the cluster, or set parallelism to -1 for Ray Data "
        "to automatically determine the parallelism. "
        "You can ignore this message if the cluster is expected to autoscale."
    )
def _check_import(obj, *, module: str, package: str) -> None:
    """Verify that a dependency required by ``obj`` can be imported.

    Args:
        obj: The object that has a dependency; its class name appears in the
            error message.
        module: The name of the module to import.
        package: The name of the package on PyPI.

    Raises:
        ImportError: If ``module`` can't be imported. The message tells the
            user to install ``package`` with pip.
    """
    try:
        importlib.import_module(module)
    except ImportError:
        dependent = obj.__class__.__name__
        problem = f"`{dependent}` depends on '{module}', but Ray Data couldn't "
        remedy = f"import it. Install '{module}' by running `pip install {package}`."
        raise ImportError(problem + remedy)
def _resolve_custom_scheme(path: str) -> str:
    """Resolve a path that uses one of Ray's custom schemes.

    The supported custom schemes are: "local", "example".

     - ``local://<rest>`` resolves to ``<rest>`` (scheme stripped).
     - ``example://<rest>`` resolves to the absolute path of ``<rest>`` inside
       the ``examples/data`` directory located one level above this file's
       directory.

    Any other path is returned unchanged.
    """
    uri = urllib.parse.urlparse(path)
    scheme = uri.scheme
    if scheme != _LOCAL_SCHEME and scheme != _EXAMPLE_SCHEME:
        return path

    # Everything after "<scheme>://".
    remainder = uri.netloc + uri.path
    if scheme == _LOCAL_SCHEME:
        return remainder

    examples_dir = pathlib.Path(__file__).parent.parent / "examples" / "data"
    return str((examples_dir / remainder).resolve())
def _is_local_scheme(paths: Union[str, List[str]]) -> bool:
    """Tell whether the given path(s) use the "local" scheme.

    Args:
        paths: A path string, a ``pathlib.Path``, or a non-empty list of path
            strings.

    Returns:
        True if every path uses the local scheme, False if none does.

    Raises:
        ValueError: If ``paths`` has an unsupported type, is an empty list, or
            mixes local-scheme paths with other paths.
    """
    if isinstance(paths, str):
        candidates = [paths]
    elif isinstance(paths, pathlib.Path):
        candidates = [str(paths)]
    else:
        if not isinstance(paths, list) or not all(isinstance(p, str) for p in paths):
            raise ValueError("paths must be a path string or a list of path strings.")
        if not paths:
            raise ValueError("Must provide at least one path.")
        candidates = paths

    is_local = [
        urllib.parse.urlparse(candidate).scheme == _LOCAL_SCHEME
        for candidate in candidates
    ]
    # The paths must agree with each other: all local, or none.
    if any(is_local) and not all(is_local):
        raise ValueError(
            "The paths must all be local-scheme or not local-scheme, "
            f"but found mixed {candidates}"
        )
    return all(is_local)
def _truncated_repr(obj: Any) -> str:
    """Return ``str(obj)`` cut to 200 characters, for use in error messages.

    When the text is longer than 200 characters, the first 200 are kept and
    "..." is appended.
    """
    text = str(obj)
    return text if len(text) <= 200 else f"{text[:200]}..."
def _insert_doc_at_pattern(
    obj,
    *,
    message: str,
    pattern: str,
    insert_after: bool = True,
    directive: Optional[str] = None,
    skip_matches: int = 0,
) -> None:
    """Insert ``message`` into ``obj.__doc__`` next to an occurrence of ``pattern``.

    The docstring is stripped of surrounding whitespace, split at the chosen
    occurrence of ``pattern``, and reassembled with ``message`` in between. The
    message is indented like the first non-empty line that follows it (or, if
    there is none, like the last non-empty line that precedes it).

    Args:
        obj: The object whose ``__doc__`` is rewritten in place. A missing
            docstring (``None``) is treated as an empty one.
        message: The single-line text to insert.
        pattern: The text to search for. An empty pattern together with
            ``insert_after=True`` appends the message to the docstring.
        insert_after: If True, insert right after the matched pattern;
            otherwise insert before the line that contains the match.
        directive: If given, wrap the message in a ``.. <directive>::`` block,
            with the message indented by four extra spaces.
        skip_matches: Which occurrence of ``pattern`` to use: values up to 1
            select the first occurrence, ``n >= 2`` selects the n-th.

    Raises:
        ValueError: If ``message`` contains a newline, or if the requested
            occurrence of ``pattern`` isn't found in the docstring.
    """
    if "\n" in message:
        raise ValueError(
            "message shouldn't contain any newlines, since this function will insert "
            f"its own linebreaks when text wrapping: {message}"
        )

    # ``__doc__`` is ``None`` when the object has no docstring.
    doc = (obj.__doc__ or "").strip()

    if pattern == "" and insert_after:
        # Empty pattern + insert_after means that we want to append the message to the
        # end of the docstring.
        split_at = len(doc)
    else:
        # Search with absolute offsets into ``doc`` so that nothing preceding a
        # skipped match is lost when the docstring is put back together.
        search_from = 0
        matches_left = skip_matches
        while True:
            match_start = doc.find(pattern, search_from)
            if match_start == -1:
                raise ValueError(
                    f"Pattern {pattern} not found after {skip_matches} skips in docstring "
                    f"{doc}"
                )
            matches_left -= 1
            if matches_left <= 0:
                break
            # Move past the found pattern, since we're skipping it.
            search_from = match_start + len(pattern)

        if insert_after:
            # Split at the first character after the pattern.
            split_at = match_start + len(pattern)
        else:
            # Split at the first character in the matched line.
            split_at = doc.rfind("\n", 0, match_start) + 1

    head, tail = doc[:split_at], doc[split_at:]

    # Get indentation of the to-be-inserted text.
    reference_lines = [line for line in tail.splitlines() if line]
    if not reference_lines:
        reference_lines = [line for line in reversed(head.splitlines()) if line]
    if reference_lines:
        first_line = reference_lines[0]
        indent = " " * (len(first_line) - len(first_line.lstrip()))
    else:
        # Empty docstring: there's nothing to align with.
        indent = ""

    # Handle directive. ``message`` is known to be a single line at this point.
    if directive is not None:
        message = f"{indent}.. {directive}::\n{indent}    {message}"
    else:
        message = indent + message

    # Add two blank lines before/after message, if necessary.
    if insert_after ^ (pattern == "\n\n"):
        # Only two blank lines before message if:
        # 1. Inserting message after pattern and pattern is not two blank lines.
        # 2. Inserting message before pattern and pattern is two blank lines.
        message = "\n\n" + message
    if (not insert_after) ^ (pattern == "\n\n"):
        # Only two blank lines after message if:
        # 1. Inserting message before pattern and pattern is not two blank lines.
        # 2. Inserting message after pattern and pattern is two blank lines.
        message = message + "\n\n"

    # Build new docstring with the message between the two halves.
    obj.__doc__ = head + message + tail
def _consumption_api(
    if_more_than_read: bool = False,
    datasource_metadata: Optional[str] = None,
    extra_condition: Optional[str] = None,
    delegate: Optional[str] = None,
    pattern: str = "Examples:",
    insert_after: bool = False,
) -> Callable[[F], F]:
    """Build a decorator that documents a function as a consumption API, i.e. one
    that will trigger Dataset execution.

    Args:
        if_more_than_read: If True (and ``delegate`` isn't set), word the note
            as conditional on the dataset consisting of more than a read.
        datasource_metadata: Name of a piece of metadata mentioned as an extra
            condition in the conditional wording.
        extra_condition: A further condition added to the conditional wording.
        delegate: If set, used as the subject of the note in place of
            "This operation"; it takes precedence over the other wording options.
        pattern: The docstring pattern next to which the note is inserted.
        insert_after: Whether the note goes after (True) or before (False) the
            pattern.

    Returns:
        A decorator that inserts the note into its argument's docstring and
        returns the argument itself.
    """
    suffix = (
        " will trigger execution of the lazy transformations performed on "
        "this dataset."
    )

    if delegate:
        subject = delegate
    elif if_more_than_read:
        clauses = ["If this dataset consists of more than a read, "]
        if datasource_metadata is not None:
            clauses.append(
                f"or if the {datasource_metadata} can't be determined from the "
                "metadata provided by the datasource, "
            )
        if extra_condition is not None:
            clauses.append(extra_condition + ", ")
        clauses.append("then this operation")
        subject = "".join(clauses)
    else:
        subject = "This operation"

    message = subject + suffix

    def decorate(obj: F) -> F:
        _insert_doc_at_pattern(
            obj,
            message=message,
            pattern=pattern,
            insert_after=insert_after,
            directive="note",
        )
        return obj

    return decorate
@overload
def ConsumptionAPI(obj: F) -> F:
    ...


@overload
def ConsumptionAPI(
    *,
    if_more_than_read: bool = False,
    datasource_metadata: Optional[str] = None,
    extra_condition: Optional[str] = None,
    delegate: Optional[str] = None,
) -> Callable[[F], F]:
    ...


def ConsumptionAPI(*args, **kwargs):
    """Mark a function as a consumption API, i.e. one that will trigger Dataset
    execution.

    Works both as a bare decorator (``@ConsumptionAPI``) and as a decorator
    factory (``@ConsumptionAPI(...)``), in which case the arguments are
    forwarded to ``_consumption_api``.
    """
    used_bare = len(args) == 1 and not kwargs and callable(args[0])
    if not used_bare:
        return _consumption_api(*args, **kwargs)
    # Bare usage: the only argument is the function being decorated.
    return _consumption_api()(args[0])
def _all_to_all_api() -> Callable[[F], F]:
    """Build a decorator that documents a function as an all-to-all API, i.e. an
    operation that requires all of its inputs to be materialized before it can
    execute.

    Returns:
        A decorator that inserts a note before the "Examples:" section of its
        argument's docstring and returns the argument itself.
    """
    note = (
        "This operation requires all inputs to be "
        "materialized in object store for it to execute."
    )

    def decorate(obj: F) -> F:
        _insert_doc_at_pattern(
            obj,
            message=note,
            pattern="Examples:",
            insert_after=False,
            directive="note",
        )
        return obj

    return decorate
@overload
def AllToAllAPI(obj: F) -> F:
    ...


def AllToAllAPI(*args, **kwargs):
    """Mark a function as an all-to-all API, i.e. an operation that requires all
    of its inputs to be materialized before it can execute.

    Only bare usage (``@AllToAllAPI``) is supported.
    """
    # This should only be used as a decorator for dataset methods.
    assert len(args) == 1 and len(kwargs) == 0 and callable(args[0])
    decorate = _all_to_all_api()
    return decorate(args[0])
def get_compute_strategy(
    fn: "UserDefinedFunction",
    fn_constructor_args: Optional[Iterable[Any]] = None,
    compute: Optional[Union[str, "ComputeStrategy"]] = None,
    concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
) -> "ComputeStrategy":
    """Pick the `ComputeStrategy` for a UDF from its kind and the given settings.

    ``compute`` takes precedence over the deprecated ``concurrency``; if neither
    is given, callable classes get an actor pool and everything else gets a
    task pool.

    Args:
        fn: The function or generator to apply to a record batch, or a class type
            that can be instantiated to create such a callable.
        fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
        compute: Either "tasks" (default) to use Ray Tasks or an
            :class:`~ray.data.ActorPoolStrategy` to use an autoscaling actor pool.
        concurrency: The number of Ray workers to use concurrently. Deprecated
            in favor of ``compute``.

    Returns:
        The `ComputeStrategy` for execution. If ``compute`` is given and valid,
        it is returned as is.

    Raises:
        ValueError: If ``fn_constructor_args`` is given for something that
            isn't a callable class, if ``compute`` doesn't fit the kind of
            ``fn``, or if ``concurrency`` is malformed.
    """
    # Lazily import these objects to avoid circular imports.
    from ray.data._internal.compute import ActorPoolStrategy, TaskPoolStrategy
    from ray.data.block import CallableClass

    # TODO(chengsu): disallow object that is not a function. For example,
    # An object instance of class often indicates a bug in user code.
    is_callable_class = isinstance(fn, CallableClass)
    if not is_callable_class and fn_constructor_args is not None:
        raise ValueError(
            "``fn_constructor_args`` can only be specified if providing a "
            f"callable class instance for ``fn``, but got: {fn}."
        )

    if compute is not None:
        if is_callable_class:
            if compute == "tasks" or isinstance(compute, TaskPoolStrategy):
                raise ValueError(
                    f"You specified the callable class {fn} as your UDF with the compute "
                    f"{compute}, but Ray Data can't schedule callable classes with the task "
                    "pool strategy. To fix this error, pass an ActorPoolStrategy to compute or "
                    "None to use the default compute strategy."
                )
        elif compute == "actors" or isinstance(compute, ActorPoolStrategy):
            raise ValueError(
                f"You specified the function {fn} as your UDF with the compute "
                f"{compute}, but Ray Data can't schedule regular functions with the actor "
                "pool strategy. To fix this error, pass a TaskPoolStrategy to compute or "
                "None to use the default compute strategy."
            )
        return compute

    if concurrency is None:
        # Nothing specified: fall back to the default for this kind of UDF.
        if is_callable_class:
            return ActorPoolStrategy(min_size=1, max_size=None)
        return TaskPoolStrategy()

    # Legacy code path to support `concurrency` argument.
    logger.warning(
        "The argument ``concurrency`` is deprecated in Ray 2.51. Please specify "
        "argument ``compute`` instead. For more information, see "
        "https://docs.ray.io/en/master/data/transforming-data.html#"
        "stateful-transforms."
    )

    if isinstance(concurrency, tuple):
        # Validate tuple length and that all elements are integers
        well_formed = len(concurrency) in (2, 3) and all(
            isinstance(c, int) for c in concurrency
        )
        if not well_formed:
            raise ValueError(
                "``concurrency`` is expected to be set as a tuple of "
                f"integers, but got: {concurrency}."
            )
        # A tuple describes an actor pool, which only callable classes can use.
        if not is_callable_class:
            raise ValueError(
                "``concurrency`` is set as a tuple of integers, but ``fn`` "
                f"is not a callable class: {fn}. Use ``concurrency=n`` to "
                "control maximum number of workers to use."
            )
        # (min, max) or (min, max, initial)
        pool_kwargs = {"min_size": concurrency[0], "max_size": concurrency[1]}
        if len(concurrency) == 3:
            pool_kwargs["initial_size"] = concurrency[2]
        return ActorPoolStrategy(**pool_kwargs)

    if isinstance(concurrency, int):
        strategy_cls = ActorPoolStrategy if is_callable_class else TaskPoolStrategy
        return strategy_cls(size=concurrency)

    raise ValueError(
        "``concurrency`` is expected to be set as an integer or a "
        f"tuple of integers, but got: {concurrency}."
    )
def capfirst(s: str) -> str:
    """Capitalize the first letter of a string, leaving the rest untouched.

    Args:
        s: String to capitalize. May be empty.

    Returns:
        Capitalized string; an empty string is returned as is.
    """
    # Slicing (rather than indexing) keeps this safe for the empty string.
    return s[:1].upper() + s[1:]
def capitalize(s: str) -> str:
    """Capitalize a string, removing '_' and keeping camelcase.

    Each '_'-separated segment gets its first letter upper-cased; the rest of
    the segment is left untouched. Empty segments (from leading, trailing or
    repeated underscores, or an empty input) contribute nothing.

    Args:
        s: String to capitalize

    Returns:
        Capitalized string with no underscores.
    """
    # Slicing (rather than indexing) keeps empty segments from raising.
    return "".join(part[:1].upper() + part[1:] for part in s.split("_"))
def pandas_df_to_arrow_block(
    df: "pandas.DataFrame",
) -> Tuple["Block", "BlockMetadataWithSchema"]:
    """Convert a pandas DataFrame to an Arrow block and build its metadata.

    Args:
        df: The DataFrame to convert.

    Returns:
        Tuple of the Arrow block and the ``BlockMetadataWithSchema`` built from
        it.
    """
    from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadataWithSchema

    accessor = BlockAccessor.for_block(df)
    arrow_block = accessor.to_arrow()
    stats_builder = BlockExecStats.builder()
    metadata = BlockMetadataWithSchema.from_block(
        arrow_block, stats=stats_builder.build()
    )
    return arrow_block, metadata
def ndarray_to_block(
    ndarray: np.ndarray, ctx: DataContext
) -> Tuple["Block", "BlockMetadataWithSchema"]:
    """Wrap an ndarray in a block with a single "data" column and build its metadata.

    ``ctx`` is installed as the current ``DataContext`` before the block is
    built.

    Args:
        ndarray: The array to store under the "data" column.
        ctx: The ``DataContext`` to make current.

    Returns:
        Tuple of the block and the ``BlockMetadataWithSchema`` built from it.
    """
    from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadataWithSchema

    DataContext._set_current(ctx)

    stats_builder = BlockExecStats.builder()
    batch = {"data": ndarray}
    new_block = BlockAccessor.batch_to_block(batch)
    metadata = BlockMetadataWithSchema.from_block(
        new_block, stats=stats_builder.build()
    )
    return new_block, metadata
def get_table_block_metadata_schema(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
) -> "BlockMetadataWithSchema":
    """Build the ``BlockMetadataWithSchema`` for a tabular block.

    Args:
        table: The Arrow table or pandas DataFrame to describe.

    Returns:
        The metadata built from ``table`` with freshly created exec stats.
    """
    from ray.data.block import BlockExecStats, BlockMetadataWithSchema

    exec_stats = BlockExecStats.builder().build()
    return BlockMetadataWithSchema.from_block(table, stats=exec_stats)
def unify_block_metadata_schema(
    block_metadata_with_schemas: List["BlockMetadataWithSchema"],
) -> Optional["Schema"]:
    """Unify the schemas carried by a list of block metadata.

    Only metadata with a schema and with a row count that is either unknown or
    positive is considered; empty blocks are skipped because their schema
    can't be relied upon.

    Args:
        block_metadata_with_schemas: List of BlockMetadata to unify

    Returns:
        A unified schema of the input list of schemas, or None if no valid schemas
        are provided.
    """
    # TODO(ekl) validate schema is the same across different blocks.
    usable_schemas = [
        meta.schema
        for meta in block_metadata_with_schemas
        if meta.schema is not None and (meta.num_rows is None or meta.num_rows > 0)
    ]
    return unify_schemas_with_validation(usable_schemas)
def unify_schemas_with_validation(
    schemas_to_unify: Iterable["Schema"],
) -> Optional["Schema"]:
    """Unify a collection of schemas into one.

    Args:
        schemas_to_unify: The schemas to unify. Any iterable is accepted,
            including one-shot iterators.

    Returns:
        ``None`` if there are no schemas. If every schema is a
        ``pyarrow.Schema``, the result of unifying them with type promotion.
        Otherwise (e.g. simple types such as ``int``), the first schema.
    """
    # Materialize once: the schemas are truth-tested, iterated over and indexed
    # below, which a one-shot iterator can't support.
    schemas = list(schemas_to_unify)
    if not schemas:
        return None

    # Check valid pyarrow installation before attempting schema unification
    try:
        import pyarrow as pa
    except ImportError:
        pa = None

    # If the result contains PyArrow schemas, unify them
    if pa is not None and all(isinstance(s, pa.Schema) for s in schemas):
        # Imported lazily, and only on the path that needs it.
        from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas

        return unify_schemas(schemas, promote_types=True)

    # Otherwise, if the resulting schemas are simple types (e.g. int),
    # return the first schema.
    return schemas[0]
def unify_ref_bundles_schema(
    ref_bundles: List["RefBundle"],
) -> Optional["Schema"]:
    """Unify the schemas carried by a list of ref bundles.

    Only bundles with a schema and with a row count that is either unknown or
    positive are considered.

    Args:
        ref_bundles: The bundles whose schemas should be unified.

    Returns:
        The unified schema, or None if no bundle qualifies.
    """
    usable_schemas = [
        bundle.schema
        for bundle in ref_bundles
        if bundle.schema is not None
        and (bundle.num_rows() is None or bundle.num_rows() > 0)
    ]
    return unify_schemas_with_validation(usable_schemas)
def find_partition_index(
    table: Union["pyarrow.Table", "pandas.DataFrame"],
    desired: Tuple[Union[int, float]],
    sort_key: "SortKey",
) -> int:
    """Find where ``desired`` would have to be inserted into a sorted block to
    keep it sorted.

    The sort columns are visited in order, starting with the primary one. For
    each column, a binary search over the rows still under consideration
    narrows the candidate row range ``[lo, hi)`` to the rows whose value in
    that column equals the corresponding component of ``desired``.

    Args:
        table: The block to search in.
        desired: A single tuple representing the boundary to partition at.
            ``len(desired)`` must be less than or equal to the number of columns
            being sorted.
        sort_key: The sort key to use for sorting, providing the columns to be
            sorted and their directions.

    Returns:
        The index where the desired value should be inserted to maintain sorted
        order.
    """
    sort_columns = sort_key.get_columns()
    sort_descending = sort_key.get_descending()

    lo, hi = 0, len(table)
    for level, target in enumerate(desired):
        if lo == hi:
            # The candidate range is empty; the insertion point is settled.
            return hi

        window = table[sort_columns[level]].to_numpy()[lo:hi]

        # Handle null values - replace them with sentinel values
        if target is None:
            target = NULL_SENTINEL

        base = lo
        if sort_descending[level] is True:
            # ``np.searchsorted`` expects ascending input, so we hand it a
            # ``sorter`` that reads ``window`` back to front. The ranks it
            # returns refer to that ascending view and have to be mirrored
            # (subtracted from the window size) to index into the original
            # descending order.
            size = len(window)
            reverse_order = np.arange(size - 1, -1, -1)
            rank_after = np.searchsorted(
                window, target, side="right", sorter=reverse_order
            )
            rank_before = np.searchsorted(
                window, target, side="left", sorter=reverse_order
            )
            lo = base + (size - rank_after)
            hi = base + (size - rank_before)
        else:
            lo = base + np.searchsorted(window, target, side="left")
            hi = base + np.searchsorted(window, target, side="right")

    return hi if sort_descending[0] is True else lo
def get_attribute_from_class_name(class_name: str) -> Any:
    """Resolve a fully-qualified dotted name to the Python object it names.

    Everything before the last dot is imported as a module and the part after
    the last dot is looked up as an attribute of that module. The caller is
    responsible for passing a name whose module can be imported.

    Args:
        class_name: Dotted path such as ``"package.module.Attribute"``.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``class_name`` contains no dot, i.e. no module part.
    """
    from importlib import import_module

    module_name, dot, attribute_name = class_name.rpartition(".")
    if not dot:
        raise ValueError(f"Cannot create object from {class_name}.")
    return getattr(import_module(module_name), attribute_name)
# Generic element types used in the signature of ``make_async_gen``:
# ``T`` for items of the input iterator, ``U`` for items of the output.
T = TypeVar("T")
U = TypeVar("U")
- class _InterruptibleQueue(Queue):
- """Extension of Python's `queue.Queue` providing ability to get interrupt its
- method callers in other threads"""
- INTERRUPTION_CHECK_FREQUENCY_SEC = 0.5
- def __init__(
- self, max_size: int, interrupted_event: Optional[threading.Event] = None
- ):
- super().__init__(maxsize=max_size)
- self._interrupted_event = interrupted_event or threading.Event()
- def get(self, block=True, timeout=None):
- if not block or timeout is not None:
- return super().get(block, timeout)
- # In case when the call is blocking and no timeout is specified (ie blocking
- # indefinitely) we apply the following protocol to make it interruptible:
- #
- # 1. `Queue.get` is invoked w/ 500ms timeout
- # 2. `Empty` exception is intercepted (will be raised upon timeout elapsing)
- # 3. If interrupted flag is set `InterruptedError` is raised
- # 4. Otherwise, protocol retried (until interrupted or queue
- # becoming non-empty)
- while True:
- if self._interrupted_event.is_set():
- raise InterruptedError()
- try:
- return super().get(
- block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
- )
- except Empty:
- pass
- def put(self, item, block=True, timeout=None):
- if not block or timeout is not None:
- super().put(item, block, timeout)
- return
- # In case when the call is blocking and no timeout is specified (ie blocking
- # indefinitely) we apply the following protocol to make it interruptible:
- #
- # 1. `Queue.pet` is invoked w/ 500ms timeout
- # 2. `Full` exception is intercepted (will be raised upon timeout elapsing)
- # 3. If interrupted flag is set `InterruptedError` is raised
- # 4. Otherwise, protocol retried (until interrupted or queue
- # becomes non-full)
- while True:
- if self._interrupted_event.is_set():
- raise InterruptedError()
- try:
- super().put(
- item, block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC
- )
- return
- except Full:
- pass
def make_async_gen(
    base_iterator: Iterator[T],
    fn: Callable[[Iterator[T]], Iterator[U]],
    preserve_ordering: bool,
    num_workers: int = 1,
    buffer_size: int = 1,
) -> Generator[U, None, None]:
    """Returns a generator that maps the items of ``base_iterator`` through
    ``fn`` in parallel, using a pool of daemon threads.

    NOTE: There are some important constraints that need to be understood
    before using this method.

    1. If ``preserve_ordering`` is True
        a. The input iterator is unrolled eagerly (irrespective of how fast
           the resulting generator is consumed): the input queues are
           unbounded in this mode, since liveness of the algorithm AND
           preservation of the original ordering can not be guaranteed at
           the same time otherwise.
        b. The ordering of the output "matches" the ordering of the input:

            iterator = [A1, A2, ... An]
            output iterator = [map(A1), map(A2), ..., map(An)]

    2. If ``preserve_ordering`` is False
        a. The single shared input queue is bounded to
           ``num_workers * (buffer_size + 1)`` elements.
        b. The ordering of the output is unspecified (and non-deterministic).

    Args:
        base_iterator: Iterator yielding elements to map.
        fn: Transformation applied by each worker to the iterator of elements
            assigned to it.
        preserve_ordering: Whether ordering has to be preserved.
        num_workers: The number of transforming threads (defaults to 1).
        buffer_size: Capacity of each output queue (defaults to 1); it also
            determines the input queue capacity when ordering is not
            preserved.

    Returns:
        A generator of the elements produced by the transformation (in input
        order when ``preserve_ordering`` is True).
    """
    gen_id = random.randint(0, 2**31 - 1)

    if num_workers < 1:
        raise ValueError("Size of threadpool must be at least 1.")

    # Shared flag used to interrupt the workers when the generator terminates
    interrupted_event = threading.Event()

    # To apply transformations in parallel *and* preserve the ordering, the
    # following invariants are established:
    #
    #   - Every worker runs in its own thread
    #   - Every worker is assigned an input and an output queue
    #
    # And the following protocol is implemented:
    #
    #   - The filling worker traverses the input iterator, round-robin'ing
    #     elements across the input queues (in order!)
    #   - Transforming workers traverse their input queue in-order: dequeue
    #     an element, apply the transformation, enqueue the result into the
    #     output queue
    #   - The generator traverses the output queues (in the same order as the
    #     input queues), dequeues 1 mapped element at a time from each of
    #     them and yields it
    #
    # However, when ordering is preserved the input queue size can not be
    # bounded, as this could result in deadlocks since transformations could
    # be producing sequences of arbitrary length.
    #
    # Check `test_make_async_gen_varying_seq_length_stress_test` for more
    # context on this problem.
    if preserve_ordering:
        input_queue_buf_size, num_input_queues = -1, num_workers
    else:
        input_queue_buf_size, num_input_queues = (buffer_size + 1) * num_workers, 1

    input_queues = [
        _InterruptibleQueue(input_queue_buf_size, interrupted_event)
        for _ in range(num_input_queues)
    ]
    output_queues = [
        _InterruptibleQueue(buffer_size, interrupted_event)
        for _ in range(num_workers)
    ]

    # Filling worker
    def _run_filling_worker():
        try:
            # Round-robin elements from the iterator into the input queues
            # (one by one)
            for position, element in enumerate(base_iterator):
                input_queues[position % num_input_queues].put(element)

            # NOTE: One sentinel has to be enqueued for every transforming
            #       worker:
            #   - When preserving order (one input queue per worker), every
            #     queue receives 1 sentinel
            #   - When NOT preserving order, all ``num_workers`` sentinels go
            #     into the single shared queue
            for worker_idx in range(num_workers):
                input_queues[worker_idx % num_input_queues].put(SENTINEL)
        except InterruptedError:
            return
        except Exception as e:
            logger.warning("Caught exception in filling worker!", exc_info=e)
            # The exception has to be propagated to the (main) iterating
            # thread. Output queues are traversed *backwards* relative to the
            # iterating thread, such that the two are more likely to meet
            # within a single iteration.
            for output_queue in reversed(output_queues):
                output_queue.put(e)

    # Transforming worker
    def _run_transforming_worker(input_queue, output_queue):
        try:
            # Iterator draining the queue until it receives the sentinel
            #
            # NOTE: `queue.get` is blocking!
            pending_inputs = iter(input_queue.get, SENTINEL)

            for mapped in fn(pending_inputs):
                output_queue.put(mapped)

            # Signal that this worker's transformations are completed
            output_queue.put(SENTINEL)
        except InterruptedError:
            return
        except Exception as e:
            logger.warning("Caught exception in transforming worker!", exc_info=e)
            # NOTE: In this case the exception is simply enqueued rather than
            #       interrupting
            output_queue.put(e)

    # Start worker threads
    filling_worker_thread = threading.Thread(
        target=_run_filling_worker,
        name=f"map_tp_filling_worker-{gen_id}",
        daemon=True,
    )
    filling_worker_thread.start()

    transforming_worker_threads = [
        threading.Thread(
            target=_run_transforming_worker,
            name=f"map_tp_transforming_worker-{gen_id}-{idx}",
            args=(input_queues[idx % num_input_queues], output_queues[idx]),
            daemon=True,
        )
        for idx in range(num_workers)
    ]
    for worker_thread in transforming_worker_threads:
        worker_thread.start()

    # The calling thread yields the outputs
    try:
        # Output queues that have not delivered their sentinel yet
        live_queues = output_queues

        while live_queues:
            # Deterministic ordering of the produced iterator relies on the
            # following invariants:
            #
            #   - Elements of the original iterator are round-robin'd into
            #     the input queues (in order)
            #   - Each worker drains its input queue in order, so its output
            #     queue preserves the ordering of its input queue
            #   - Output queues are traversed in the same order, and a single
            #     element is dequeued (in a blocking way!) at a time from
            #     every individual output queue
            finished_queues = []

            # Only queues that are not exhausted yet are traversed (to avoid
            # blocking on an exhausted queue)
            for output_queue in live_queues:
                # NOTE: This is blocking!
                item = output_queue.get()

                if isinstance(item, Exception):
                    raise item

                if item is SENTINEL:
                    finished_queues.append(output_queue)
                else:
                    yield item

            if finished_queues:
                live_queues = [q for q in live_queues if q not in finished_queues]
    finally:
        # Interrupt the workers, which might otherwise stay blocked on their
        # queues while holding on to objects.
        #
        # NOTE: Even though the threads are being interrupted, there is no
        #       guarantee that this happens in time (it depends on Python's
        #       GC finalizer closing the generator by raising
        #       `GeneratorExit`), hence neither the filling nor the
        #       transforming workers are joined here.
        interrupted_event.set()
class RetryingContextManager:
    """Context manager that enters and exits a wrapped file object with retries.

    ``__enter__`` and ``__exit__`` of the wrapped object are executed through
    ``call_with_retry``, matching errors against
    ``context.retried_io_errors``.
    """

    def __init__(
        self,
        f: "pyarrow.NativeFile",
        context: "DataContext",
        max_attempts: int = 10,
        max_backoff_s: int = 32,
    ):
        """
        Args:
            f: The file object whose context-manager protocol is wrapped.
            context: Object providing ``retried_io_errors``, the error
                patterns passed to ``call_with_retry`` as ``match``.
            max_attempts: Passed to ``call_with_retry`` as ``max_attempts``.
            max_backoff_s: Passed to ``call_with_retry`` as ``max_backoff_s``.
        """
        self._f = f
        self._data_context = context
        self._max_attempts = max_attempts
        self._max_backoff_s = max_backoff_s

    def __repr__(self):
        # This class wraps a file (``self._f``) and has no ``handler``
        # attribute, so the representation reports the wrapped file.
        return f"<{self.__class__.__name__} f={self._f!r}>"

    def _retry_operation(self, operation: Callable, description: str):
        """Execute an operation with retries.

        Args:
            operation: No-argument callable to execute.
            description: Description of the operation, forwarded to
                ``call_with_retry``.

        Returns:
            Whatever ``call_with_retry`` returns for ``operation``.
        """
        return call_with_retry(
            operation,
            description=description,
            match=self._data_context.retried_io_errors,
            max_attempts=self._max_attempts,
            max_backoff_s=self._max_backoff_s,
        )

    def __enter__(self):
        """Enter the wrapped file's context (with retries) and return its result."""
        return self._retry_operation(self._f.__enter__, "enter file context")

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the wrapped file's context (with retries).

        The wrapped ``__exit__`` result is discarded, so this method returns
        ``None`` and never suppresses an exception raised in the ``with`` body.
        """
        self._retry_operation(
            lambda: self._f.__exit__(exc_type, exc_value, traceback),
            "exit file context",
        )
class RetryingPyFileSystem(pyarrow.fs.PyFileSystem):
    """``pyarrow.fs.PyFileSystem`` backed by a ``RetryingPyFileSystemHandler``.

    Use ``wrap`` to create an instance from an existing filesystem and
    ``unwrap`` to get the underlying filesystem back.
    """

    def __init__(self, handler: "RetryingPyFileSystemHandler"):
        """
        Args:
            handler: The handler performing the filesystem operations.

        Raises:
            ValueError: If ``handler`` is not a ``RetryingPyFileSystemHandler``.
        """
        if not isinstance(handler, RetryingPyFileSystemHandler):
            # NOTE: This has to ``raise``. ``assert ValueError(...)`` would
            #       never fail, since an exception instance is always truthy.
            raise ValueError("handler must be a RetryingPyFileSystemHandler")
        super().__init__(handler)

    @property
    def retryable_errors(self) -> List[str]:
        """The retryable error patterns configured on the handler."""
        return self.handler._retryable_errors

    def unwrap(self):
        """Return the filesystem wrapped by the handler."""
        return self.handler.unwrap()

    @classmethod
    def wrap(
        cls,
        fs: "pyarrow.fs.FileSystem",
        retryable_errors: List[str],
        max_attempts: int = 10,
        max_backoff_s: int = 32,
    ):
        """Wrap ``fs`` into a retrying filesystem.

        Args:
            fs: The filesystem to wrap.
            retryable_errors: Error patterns handed to the handler.
            max_attempts: Handed to the handler as ``max_attempts``.
            max_backoff_s: Handed to the handler as ``max_backoff_s``.

        Returns:
            ``fs`` itself if it already is a ``RetryingPyFileSystem``,
            otherwise a new instance of ``cls`` wrapping it.
        """
        if isinstance(fs, RetryingPyFileSystem):
            return fs
        handler = RetryingPyFileSystemHandler(
            fs, retryable_errors, max_attempts, max_backoff_s
        )
        return cls(handler)

    def __reduce__(self):
        # Serialization of this class breaks for some reason without this
        return (self.__class__, (self.handler,))

    @classmethod
    def __setstate__(cls, state):
        # Serialization of this class breaks for some reason without this
        return cls(*state)
class RetryingPyFileSystemHandler(pyarrow.fs.FileSystemHandler):
    """Filesystem handler that adds retries to the operations of a wrapped
    filesystem.

    Most operations are forwarded to the wrapped filesystem through
    ``call_with_retry``. ``equals``, ``get_type_name`` and ``unwrap`` are not
    retried.
    """

    def __init__(
        self,
        fs: "pyarrow.fs.FileSystem",
        retryable_errors: List[str] = tuple(),
        max_attempts: int = 10,
        max_backoff_s: int = 32,
    ):
        """Initialize the retrying filesystem wrapper.

        Args:
            fs: The underlying filesystem to wrap. Must not be a
                ``RetryingPyFileSystem``.
            retryable_errors: Error patterns passed to ``call_with_retry`` as
                ``match``.
            max_attempts: Maximum number of retry attempts
            max_backoff_s: Maximum backoff time in seconds
        """
        assert not isinstance(
            fs, RetryingPyFileSystem
        ), "Cannot wrap a RetryingPyFileSystem"
        self._fs = fs
        self._retryable_errors = retryable_errors
        self._max_attempts = max_attempts
        self._max_backoff_s = max_backoff_s

    def _retry_operation(self, operation: Callable, description: str):
        """Execute a no-argument ``operation`` via ``call_with_retry`` using
        this handler's retry settings."""
        return call_with_retry(
            operation,
            description=description,
            match=self._retryable_errors,
            max_attempts=self._max_attempts,
            max_backoff_s=self._max_backoff_s,
        )

    def unwrap(self):
        """Return the wrapped filesystem."""
        return self._fs

    def copy_file(self, src: str, dest: str):
        """Copy a file."""
        description = f"copy file from {src} to {dest}"
        return self._retry_operation(
            lambda: self._fs.copy_file(src, dest), description
        )

    def create_dir(self, path: str, recursive: bool):
        """Create a directory and subdirectories."""
        description = f"create directory {path}"
        return self._retry_operation(
            lambda: self._fs.create_dir(path, recursive=recursive), description
        )

    def delete_dir(self, path: str):
        """Delete a directory and its contents, recursively."""
        description = f"delete directory {path}"
        return self._retry_operation(lambda: self._fs.delete_dir(path), description)

    def delete_dir_contents(self, path: str, missing_dir_ok: bool = False):
        """Delete a directory's contents, recursively."""
        description = f"delete directory contents {path}"
        return self._retry_operation(
            lambda: self._fs.delete_dir_contents(path, missing_dir_ok=missing_dir_ok),
            description,
        )

    def delete_file(self, path: str):
        """Delete a file."""
        description = f"delete file {path}"
        return self._retry_operation(lambda: self._fs.delete_file(path), description)

    def delete_root_dir_contents(self):
        """Delete the contents of the root directory (``"/"``)."""
        return self._retry_operation(
            lambda: self._fs.delete_dir_contents("/", accept_root_dir=True),
            "delete root dir contents",
        )

    def equals(self, other: "pyarrow.fs.FileSystem") -> bool:
        """Test if the wrapped filesystem equals another (not retried)."""
        return self._fs.equals(other)

    def get_file_info(self, paths: List[str]):
        """Get info for the given files."""
        description = f"get file info for {paths}"
        return self._retry_operation(
            lambda: self._fs.get_file_info(paths), description
        )

    def get_file_info_selector(self, selector):
        """Get info for the files matched by the given selector."""
        description = f"get file info for {selector}"
        return self._retry_operation(
            lambda: self._fs.get_file_info(selector), description
        )

    def get_type_name(self):
        """Return the type name reported for this filesystem."""
        return "RetryingPyFileSystem"

    def move(self, src: str, dest: str):
        """Move / rename a file or directory."""
        description = f"move from {src} to {dest}"
        return self._retry_operation(lambda: self._fs.move(src, dest), description)

    def normalize_path(self, path: str) -> str:
        """Normalize filesystem path."""
        description = f"normalize path {path}"
        return self._retry_operation(
            lambda: self._fs.normalize_path(path), description
        )

    def open_append_stream(
        self,
        path: str,
        metadata=None,
    ) -> "pyarrow.NativeFile":
        """Open an output stream for appending.

        Compression is disabled in this method because it is handled in the
        PyFileSystem abstract class.
        """

        def _open():
            return self._fs.open_append_stream(
                path,
                compression=None,
                metadata=metadata,
            )

        return self._retry_operation(_open, f"open append stream for {path}")

    def open_input_stream(
        self,
        path: str,
    ) -> "pyarrow.NativeFile":
        """Open an input stream for sequential reading.

        Compression is disabled in this method because it is handled in the
        PyFileSystem abstract class.
        """

        def _open():
            return self._fs.open_input_stream(path, compression=None)

        return self._retry_operation(_open, f"open input stream for {path}")

    def open_output_stream(
        self,
        path: str,
        metadata=None,
    ) -> "pyarrow.NativeFile":
        """Open an output stream for sequential writing.

        Compression is disabled in this method because it is handled in the
        PyFileSystem abstract class.
        """

        def _open():
            return self._fs.open_output_stream(
                path,
                compression=None,
                metadata=metadata,
            )

        return self._retry_operation(_open, f"open output stream for {path}")

    def open_input_file(self, path: str) -> "pyarrow.NativeFile":
        """Open an input file for random access reading."""
        description = f"open input file {path}"
        return self._retry_operation(
            lambda: self._fs.open_input_file(path), description
        )
def iterate_with_retry(
    iterable_factory: Callable[[], Iterable],
    description: str,
    *,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Any:
    """Iterate through an iterable with retries.

    If the iterable raises an exception, the iterable is recreated and
    iterated again from the start, skipping the items that have already been
    yielded.

    Args:
        iterable_factory: A no-argument function that creates the iterable.
        description: An imperative description of the function being retried.
            For example, "open the file".
        match: A list of strings to match in the exception message. If
            ``None``, any error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to backoff.

    Yields:
        The items of the iterable, each one exactly once.
    """
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    already_yielded = 0
    for attempt in range(max_attempts):
        try:
            for position, item in enumerate(iterable_factory()):
                # Items before ``already_yielded`` were produced by an earlier
                # attempt and must not be yielded again.
                if position >= already_yielded:
                    already_yielded += 1
                    yield item
            return
        except Exception as e:
            retryable = match is None or any(pattern in str(e) for pattern in match)
            attempts_exhausted = attempt + 1 >= max_attempts
            if not retryable or attempts_exhausted:
                raise e from None

            # Retry with binary exponential backoff with random jitter.
            backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random()
            logger.debug(
                f"Retrying {attempt+1} attempts to {description} "
                f"after {backoff} seconds."
            )
            time.sleep(backoff)
def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
    """Format a byte count as a rounded, decimal-unit string.

    The largest unit among GB (1e9), MB (1e6) and KB (1e3) whose threshold is
    reached is used; values below 1e6 bytes are always reported in KB.

    Args:
        num_bytes: The number of bytes.

    Returns:
        A string such as ``"3GB"``, ``"5MB"`` or ``"2KB"``.
    """
    for threshold, unit in ((1e9, "GB"), (1e6, "MB")):
        if num_bytes >= threshold:
            return f"{round(num_bytes / threshold)}{unit}"
    return f"{round(num_bytes / 1e3)}KB"
- def _validate_rows_per_file_args(
- *,
- num_rows_per_file: Optional[int] = None,
- min_rows_per_file: Optional[int] = None,
- max_rows_per_file: Optional[int] = None,
- ) -> Tuple[Optional[int], Optional[int]]:
- """Helper method to validate and handle rows per file arguments.
- Args:
- num_rows_per_file: Deprecated parameter for number of rows per file
- min_rows_per_file: New parameter for minimum rows per file
- max_rows_per_file: New parameter for maximum rows per file
- Returns:
- A tuple of (effective_min_rows_per_file, effective_max_rows_per_file)
- """
- if num_rows_per_file is not None:
- import warnings
- warnings.warn(
- "`num_rows_per_file` is deprecated and will be removed in a future release. "
- "Use `min_rows_per_file` instead.",
- DeprecationWarning,
- stacklevel=3,
- )
- if min_rows_per_file is not None:
- raise ValueError(
- "Cannot specify both `num_rows_per_file` and `min_rows_per_file`. "
- "Use `min_rows_per_file` as `num_rows_per_file` is deprecated."
- )
- min_rows_per_file = num_rows_per_file
- # Validate max_rows_per_file
- if max_rows_per_file is not None and max_rows_per_file <= 0:
- raise ValueError("max_rows_per_file must be a positive integer")
- # Validate min_rows_per_file
- if min_rows_per_file is not None and min_rows_per_file <= 0:
- raise ValueError("min_rows_per_file must be a positive integer")
- # Validate that max >= min if both are specified
- if (
- min_rows_per_file is not None
- and max_rows_per_file is not None
- and min_rows_per_file > max_rows_per_file
- ):
- raise ValueError(
- f"min_rows_per_file ({min_rows_per_file}) cannot be greater than "
- f"max_rows_per_file ({max_rows_per_file})"
- )
- return min_rows_per_file, max_rows_per_file
def is_nan(value) -> bool:
    """Tell whether ``value`` is a float NaN.

    Anything that is not a ``float`` instance is never considered NaN.
    """
    if not isinstance(value, float):
        return False
    try:
        return np.isnan(value)
    except TypeError:
        return False
def is_null(value: Any) -> bool:
    """Generalization of ``is_nan`` that treats both ``None`` and NaN as null
    values."""
    if value is None:
        return True
    return is_nan(value)
def keys_equal(keys1, keys2):
    """Compare two key sequences element-wise, treating NaN as equal to NaN.

    Returns ``False`` if the sequences differ in length or if any pair of
    elements is neither equal nor both NaN; ``True`` otherwise.
    """
    if len(keys1) != len(keys2):
        return False
    return all(
        (is_nan(left) and is_nan(right)) or left == right
        for left, right in zip(keys1, keys2)
    )
def get_total_obj_store_mem_on_node() -> int:
    """Return the total object store memory on the current node.

    The value is the ``"object_store_memory"`` entry of the current node in
    ``ray._private.state.total_resources_per_node()``.

    This function incurs an RPC. Use it cautiously.
    """
    current_node_id = ray.get_runtime_context().get_node_id()
    resources_by_node = ray._private.state.total_resources_per_node()
    assert (
        current_node_id in resources_by_node
    ), f"Expected node '{current_node_id}' to be in resources: {resources_by_node}"
    node_resources = resources_by_node[current_node_id]
    return node_resources["object_store_memory"]
class MemoryProfiler:
    """A context manager that polls the USS of the current process.

    This class approximates the max USS by polling memory and subtracting the
    amount of shared memory from the resident set size (RSS). It's not a
    perfect estimate (it can underestimate, e.g., if you use Torch tensors),
    but estimating the USS is much cheaper than computing the actual USS.

    .. warning::

        This class only works with Linux. If you use it on another platform,
        `estimate_max_uss` always returns ``None``.

    Example:

        .. testcode::

            with MemoryProfiler(poll_interval_s=1.0) as profiler:
                for i in range(10):
                    ...  # Your code here
                    print(f"Max USS: {profiler.estimate_max_uss()}")
                    profiler.reset()
    """

    def __init__(self, poll_interval_s: Optional[float]):
        """
        Args:
            poll_interval_s: The interval to poll the USS of the process. If
                `None`, this class won't poll the USS.
        """
        self._poll_interval_s = poll_interval_s
        self._process = psutil.Process(os.getpid())

        # Largest estimate observed since creation or the last ``reset()``;
        # guarded by ``_max_uss_lock``.
        self._max_uss = None
        self._max_uss_lock = threading.Lock()

        # Background poller and the event used to stop it; both are set in
        # ``__enter__`` only when polling is enabled and supported.
        self._uss_poll_thread = None
        self._stop_uss_poll_event = None

    def __repr__(self):
        return f"MemoryProfiler(poll_interval_s={self._poll_interval_s})"

    def __enter__(self):
        polling_enabled = (
            self._can_estimate_uss() and self._poll_interval_s is not None
        )
        if polling_enabled:
            poll_thread, stop_event = self._start_uss_poll_thread()
            self._uss_poll_thread = poll_thread
            self._stop_uss_poll_event = stop_event
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._uss_poll_thread is not None:
            self._stop_uss_poll_thread()

    def estimate_max_uss(self) -> Optional[int]:
        """Get an estimate of the max USS of the current process.

        A fresh sample is taken and folded into the running maximum before
        the maximum is returned.

        Returns:
            An estimate of the max USS of the process in bytes, or ``None`` if
            an estimate isn't available.
        """
        if not self._can_estimate_uss():
            assert self._max_uss is None
            return None

        with self._max_uss_lock:
            self._record_uss_sample()

        assert self._max_uss is not None
        return self._max_uss

    def reset(self):
        """Forget the maximum observed so far."""
        with self._max_uss_lock:
            self._max_uss = None

    def _record_uss_sample(self):
        """Fold a fresh USS estimate into ``_max_uss``.

        Callers must hold ``_max_uss_lock``.
        """
        if self._max_uss is None:
            self._max_uss = self._estimate_uss()
        else:
            self._max_uss = max(self._max_uss, self._estimate_uss())

    def _start_uss_poll_thread(self) -> Tuple[threading.Thread, threading.Event]:
        """Start a daemon thread sampling the USS every ``poll_interval_s``
        seconds, and return it along with the event that stops it."""
        assert self._poll_interval_s is not None
        assert self._can_estimate_uss()

        stop_event = threading.Event()

        def poll_uss():
            while not stop_event.is_set():
                with self._max_uss_lock:
                    self._record_uss_sample()
                # Doubles as the sleep between samples; returns early once
                # the event is set.
                stop_event.wait(self._poll_interval_s)

        thread = threading.Thread(target=poll_uss, daemon=True)
        thread.start()

        return thread, stop_event

    def _stop_uss_poll_thread(self):
        """Signal the polling thread to stop and wait for it to finish."""
        if self._stop_uss_poll_event is None:
            return
        self._stop_uss_poll_event.set()
        self._uss_poll_thread.join()

    def _estimate_uss(self) -> int:
        assert self._can_estimate_uss()
        # Estimate the USS (the amount of memory that'd be free if we killed
        # the process right now) as the difference between the RSS (total
        # physical memory) and amount of shared physical memory.
        info = self._process.memory_info()
        return info.rss - info.shared

    @staticmethod
    @functools.cache
    def _can_estimate_uss() -> bool:
        # MacOS and Windows don't have the 'shared' attribute of `memory_info()`.
        return platform.system() == "Linux"
def unzip(data: List[Tuple[Any, ...]]) -> Tuple[List[Any], ...]:
    """Transpose a list of tuples into a tuple of lists.

    Args:
        data: A list of tuples to unzip.

    Returns:
        A tuple of lists, where the i-th list holds the i-th element of every
        tuple in the input list. An empty input yields an empty tuple.
    """
    return tuple(list(column) for column in zip(*data))
- def _sort_df(df: pd.DataFrame) -> pd.DataFrame:
- """Sort DataFrame by columns and rows, and also handle unhashable types."""
- df = df.copy()
- def to_sortable(x):
- if isinstance(x, (list, np.ndarray)):
- return tuple(to_sortable(i) for i in x)
- if isinstance(x, dict):
- return tuple(sorted((k, to_sortable(v)) for k, v in x.items()))
- return x
- sort_cols = []
- temp_cols = []
- # Sort by all columns to ensure deterministic order.
- columns = sorted(df.columns)
- for col in columns:
- if df[col].dtype == "object":
- # Create a temporary column for sorting to handle unhashable types.
- # Use UUID to avoid collisions with existing column names.
- temp_col = f"__sort_proxy_{uuid.uuid4().hex}_{col}__"
- df[temp_col] = df[col].map(to_sortable)
- sort_cols.append(temp_col)
- temp_cols.append(temp_col)
- else:
- sort_cols.append(col)
- sorted_df = df.sort_values(sort_cols)
- if temp_cols:
- sorted_df = sorted_df.drop(columns=temp_cols)
- return sorted_df
def rows_same(actual: pd.DataFrame, expected: pd.DataFrame) -> bool:
    """Check if two DataFrames have the same rows.

    Unlike the built-in pandas equals method, this function ignores indices and
    the order of rows. This is useful for testing Ray Data because its interface
    doesn't usually guarantee the order of rows.

    Args:
        actual: The DataFrame under test.
        expected: The DataFrame to compare against.

    Returns:
        ``True`` if both frames hold the same rows (dtypes are not compared),
        ``False`` otherwise. Two frames with zero rows are always considered
        the same.
    """
    if len(actual) != len(expected):
        return False

    if len(actual) == 0:
        return True

    # Sort both frames the same way so that row order and indices don't matter.
    sorted_actual = _sort_df(actual).reset_index(drop=True)
    sorted_expected = _sort_df(expected).reset_index(drop=True)

    try:
        pd.testing.assert_frame_equal(
            sorted_actual,
            sorted_expected,
            check_dtype=False,
        )
    except AssertionError:
        # ``assert_frame_equal`` reports a mismatch by raising; translate that
        # into the boolean result this function promises.
        return False
    return True
def merge_resources_to_ray_remote_args(
    num_cpus: Optional[int],
    num_gpus: Optional[int],
    memory: Optional[int],
    ray_remote_args: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge the given resources into a copy of the Ray remote args.

    Resources that are ``None`` are left out, so an existing entry of the same
    name in ``ray_remote_args`` is kept. The input mapping is not modified.

    Args:
        num_cpus: The number of CPUs to be added to the Ray remote args.
        num_gpus: The number of GPUs to be added to the Ray remote args.
        memory: The memory to be added to the Ray remote args.
        ray_remote_args: The Ray remote args to be merged.

    Returns:
        The converted arguments.
    """
    merged_args = ray_remote_args.copy()
    requested = (("num_cpus", num_cpus), ("num_gpus", num_gpus), ("memory", memory))
    for resource_name, amount in requested:
        if amount is not None:
            merged_args[resource_name] = amount
    return merged_args
@DeveloperAPI
def infer_compression(path: str) -> Optional[str]:
    """Infer the compression codec from a file path.

    Args:
        path: The file path to inspect.

    Returns:
        The name of the codec detected by ``pyarrow.Codec.detect``. If the
        detection raises ``ValueError`` or ``TypeError``, ``"snappy"`` is
        returned for paths with a ``.snappy`` suffix and ``None`` otherwise.
    """
    import pyarrow as pa

    try:
        # Try to detect compression codec from path.
        return pa.Codec.detect(path).name
    except (ValueError, TypeError):
        # Arrow's compression inference on the file path doesn't work for
        # Snappy, so we double-check ourselves.
        import pathlib

        if pathlib.Path(path).suffix == ".snappy":
            return "snappy"
        return None
|