import functools import importlib import logging import os import pathlib import platform import random import sys import threading import time import urllib.parse import uuid from queue import Empty, Full, Queue from types import ModuleType from typing import ( TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union, overload, ) import numpy as np import pandas as pd # NOTE: pyarrow.fs module needs to be explicitly imported! import pyarrow import pyarrow.fs import ray from ray._common.retry import call_with_retry from ray.data.context import DEFAULT_READ_OP_MIN_NUM_BLOCKS, WARN_PREFIX, DataContext from ray.util.annotations import DeveloperAPI import psutil # TypeVar for preserving function/class signatures through decorators F = TypeVar("F", bound=Callable[..., Any]) if TYPE_CHECKING: import pandas from ray.data._internal.compute import ComputeStrategy from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.planner.exchange.sort_task_spec import SortKey from ray.data.block import ( Block, BlockMetadataWithSchema, Schema, UserDefinedFunction, ) from ray.data.datasource import Datasource, Reader from ray.util.placement_group import PlacementGroup logger = logging.getLogger(__name__) KiB = 1024 # bytes MiB = 1024 * KiB GiB = 1024 * MiB SENTINEL = object() _LOCAL_SCHEME = "local" _EXAMPLE_SCHEME = "example" LazyModule = Union[None, bool, ModuleType] _pyarrow_dataset: LazyModule = None class _OrderedNullSentinel: """Sentinel value that sorts greater than any other non-null value. NOTE: Semantic of this sentinel is closely mirroring that one of ``np.nan`` for the purpose of consistency in handling of ``None``s and ``np.nan``s. """ def __eq__(self, other): return False def __lt__(self, other): # not None < _OrderedNullSentinel # _OrderedNullSentinel < _OrderedNullSentinel # _OrderedNullSentinel < None # _OrderedNullSentinel < np.nan return isinstance(other, _OrderedNullSentinel) or is_null(other) def __le__(self, other): # NOTE: This is just a shortened version of # self < other or self == other return self.__lt__(other) def __gt__(self, other): return not self.__le__(other) def __ge__(self, other): return not self.__lt__(other) def __hash__(self): return id(self) NULL_SENTINEL = _OrderedNullSentinel() def _lazy_import_pyarrow_dataset() -> LazyModule: global _pyarrow_dataset if _pyarrow_dataset is None: try: from pyarrow import dataset as _pyarrow_dataset except ModuleNotFoundError: # If module is not found, set _pyarrow to False so we won't # keep trying to import it on every _lazy_import_pyarrow() call. _pyarrow_dataset = False return _pyarrow_dataset def _check_pyarrow_version(): ray.data._internal.utils.arrow_utils._check_pyarrow_version() def _autodetect_parallelism( parallelism: int, target_max_block_size: Optional[int], ctx: DataContext, datasource_or_legacy_reader: Optional[Union["Datasource", "Reader"]] = None, mem_size: Optional[int] = None, placement_group: Optional["PlacementGroup"] = None, avail_cpus: Optional[int] = None, ) -> Tuple[int, str, Optional[int]]: """Returns parallelism to use and the min safe parallelism to avoid OOMs. This detects parallelism using the following heuristics, applied in order: 1) We start with the default value of 200. This can be overridden by setting the `read_op_min_num_blocks` attribute of :class:`~ray.data.context.DataContext`. 2) Min block size. If the parallelism would make blocks smaller than this threshold, the parallelism is reduced to avoid the overhead of tiny blocks. 3) Max block size. If the parallelism would make blocks larger than this threshold, the parallelism is increased to avoid OOMs during processing. 4) Available CPUs. If the parallelism cannot make use of all the available CPUs in the cluster, the parallelism is increased until it can. Args: parallelism: The user-requested parallelism, or -1 for auto-detection. target_max_block_size: The target max block size to produce. We pass this separately from the DatasetContext because it may be set per-op instead of per-Dataset. ctx: The current Dataset context to use for configs. datasource_or_legacy_reader: The datasource or legacy reader, to be used for data size estimation. mem_size: If passed, then used to compute the parallelism according to target_max_block_size. placement_group: The placement group that this Dataset will execute inside, if any. avail_cpus: Override avail cpus detection (for testing only). Returns: Tuple of detected parallelism (only if -1 was specified), the reason for the detected parallelism (only if -1 was specified), and the estimated inmemory size of the dataset. """ min_safe_parallelism = 1 max_reasonable_parallelism = sys.maxsize if mem_size is None and datasource_or_legacy_reader: mem_size = datasource_or_legacy_reader.estimate_inmemory_data_size() if ( mem_size is not None and not np.isnan(mem_size) and target_max_block_size is not None ): min_safe_parallelism = max(1, int(mem_size / target_max_block_size)) max_reasonable_parallelism = max(1, int(mem_size / ctx.target_min_block_size)) reason = "" if parallelism < 0: if parallelism != -1: raise ValueError("`parallelism` must either be -1 or a positive integer.") if ( ctx.min_parallelism is not None and ctx.min_parallelism != DEFAULT_READ_OP_MIN_NUM_BLOCKS and ctx.read_op_min_num_blocks == DEFAULT_READ_OP_MIN_NUM_BLOCKS ): logger.warning( "``DataContext.min_parallelism`` is deprecated in Ray 2.10. " "Please specify ``DataContext.read_op_min_num_blocks`` instead." ) ctx.read_op_min_num_blocks = ctx.min_parallelism # Start with 2x the number of cores as a baseline, with a min floor. if placement_group is None: placement_group = ray.util.get_current_placement_group() avail_cpus = avail_cpus or _estimate_avail_cpus(placement_group) parallelism = max( min(ctx.read_op_min_num_blocks, max_reasonable_parallelism), min_safe_parallelism, avail_cpus * 2, ) if parallelism == ctx.read_op_min_num_blocks: reason = ( "DataContext.get_current().read_op_min_num_blocks=" f"{ctx.read_op_min_num_blocks}" ) elif parallelism == max_reasonable_parallelism: reason = ( "output blocks of size at least " "DataContext.get_current().target_min_block_size=" f"{ctx.target_min_block_size / MiB} MiB" ) elif parallelism == min_safe_parallelism: # Handle ``None`` (unlimited) gracefully in the log message. if ctx.target_max_block_size is None: display_val = "unlimited" else: display_val = f"{ctx.target_max_block_size / MiB} MiB" reason = ( "output blocks of size at most " "DataContext.get_current().target_max_block_size=" f"{display_val}" ) else: reason = ( "parallelism at least twice the available number " f"of CPUs ({avail_cpus})" ) logger.debug( f"Autodetected parallelism={parallelism} based on " f"estimated_available_cpus={avail_cpus} and " f"estimated_data_size={mem_size}." ) return parallelism, reason, mem_size def _estimate_avail_cpus(cur_pg: Optional["PlacementGroup"]) -> int: """Estimates the available CPU parallelism for this Dataset in the cluster. If we aren't in a placement group, this is trivially the number of CPUs in the cluster. Otherwise, we try to calculate how large the placement group is relative to the size of the cluster. Args: cur_pg: The current placement group, if any. """ cluster_cpus = int(ray.cluster_resources().get("CPU", 1)) cluster_gpus = int(ray.cluster_resources().get("GPU", 0)) # If we're in a placement group, we shouldn't assume the entire cluster's # resources are available for us to use. Estimate an upper bound on what's # reasonable to assume is available for datasets to use. if cur_pg: pg_cpus = 0 for bundle in cur_pg.bundle_specs: # Calculate the proportion of the cluster this placement group "takes up". # Then scale our cluster_cpus proportionally to avoid over-parallelizing # if there are many parallel Tune trials using the cluster. cpu_fraction = bundle.get("CPU", 0) / max(1, cluster_cpus) gpu_fraction = bundle.get("GPU", 0) / max(1, cluster_gpus) max_fraction = max(cpu_fraction, gpu_fraction) # Over-parallelize by up to a factor of 2, but no more than that. It's # preferrable to over-estimate than under-estimate. pg_cpus += 2 * int(max_fraction * cluster_cpus) return min(cluster_cpus, pg_cpus) return cluster_cpus def _estimate_available_parallelism() -> int: """Estimates the available CPU parallelism for this Dataset in the cluster. If we are currently in a placement group, take that into account.""" cur_pg = ray.util.get_current_placement_group() return _estimate_avail_cpus(cur_pg) def _warn_on_high_parallelism(requested_parallelism, num_read_tasks): available_cpu_slots = ray.available_resources().get("CPU", 1) if ( requested_parallelism and num_read_tasks > available_cpu_slots * 4 and num_read_tasks >= 5000 ): logger.warning( f"{WARN_PREFIX} The requested parallelism of {requested_parallelism} " "is more than 4x the number of available CPU slots in the cluster of " f"{available_cpu_slots}. This can " "lead to slowdowns during the data reading phase due to excessive " "task creation. Reduce the parallelism to match with the available " "CPU slots in the cluster, or set parallelism to -1 for Ray Data " "to automatically determine the parallelism. " "You can ignore this message if the cluster is expected to autoscale." ) def _check_import(obj, *, module: str, package: str) -> None: """Check if a required dependency is installed. If `module` can't be imported, this function raises an `ImportError` instructing the user to install `package` from PyPI. Args: obj: The object that has a dependency. module: The name of the module to import. package: The name of the package on PyPI. """ try: importlib.import_module(module) except ImportError: raise ImportError( f"`{obj.__class__.__name__}` depends on '{module}', but Ray Data couldn't " f"import it. Install '{module}' by running `pip install {package}`." ) def _resolve_custom_scheme(path: str) -> str: """Returns the resolved path if the given path follows a Ray-specific custom scheme. Othewise, returns the path unchanged. The supported custom schemes are: "local", "example". """ parsed_uri = urllib.parse.urlparse(path) if parsed_uri.scheme == _LOCAL_SCHEME: path = parsed_uri.netloc + parsed_uri.path elif parsed_uri.scheme == _EXAMPLE_SCHEME: example_data_path = pathlib.Path(__file__).parent.parent / "examples" / "data" path = example_data_path / (parsed_uri.netloc + parsed_uri.path) path = str(path.resolve()) return path def _is_local_scheme(paths: Union[str, List[str]]) -> bool: """Returns True if the given paths are in local scheme. Note: The paths must be in same scheme, i.e. it's invalid and will raise error if paths are mixed with different schemes. """ if isinstance(paths, str): paths = [paths] if isinstance(paths, pathlib.Path): paths = [str(paths)] elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths): raise ValueError("paths must be a path string or a list of path strings.") elif len(paths) == 0: raise ValueError("Must provide at least one path.") num = sum(urllib.parse.urlparse(path).scheme == _LOCAL_SCHEME for path in paths) if num > 0 and num < len(paths): raise ValueError( "The paths must all be local-scheme or not local-scheme, " f"but found mixed {paths}" ) return num == len(paths) def _truncated_repr(obj: Any) -> str: """Utility to return a truncated object representation for error messages.""" msg = str(obj) if len(msg) > 200: msg = msg[:200] + "..." return msg def _insert_doc_at_pattern( obj, *, message: str, pattern: str, insert_after: bool = True, directive: Optional[str] = None, skip_matches: int = 0, ) -> str: if "\n" in message: raise ValueError( "message shouldn't contain any newlines, since this function will insert " f"its own linebreaks when text wrapping: {message}" ) doc = obj.__doc__.strip() if not doc: doc = "" if pattern == "" and insert_after: # Empty pattern + insert_after means that we want to append the message to the # end of the docstring. head = doc tail = "" else: tail = doc i = tail.find(pattern) skip_matches_left = skip_matches while i != -1: if insert_after: # Set offset to the first character after the pattern. offset = i + len(pattern) else: # Set offset to the first character in the matched line. offset = tail[:i].rfind("\n") + 1 head = tail[:offset] tail = tail[offset:] skip_matches_left -= 1 if skip_matches_left <= 0: break elif not insert_after: # Move past the found pattern, since we're skipping it. tail = tail[i - offset + len(pattern) :] i = tail.find(pattern) else: raise ValueError( f"Pattern {pattern} not found after {skip_matches} skips in docstring " f"{doc}" ) # Get indentation of the to-be-inserted text. after_lines = list(filter(bool, tail.splitlines())) if len(after_lines) > 0: lines = after_lines else: lines = list(filter(bool, reversed(head.splitlines()))) # Should always have at least one non-empty line in the docstring. assert len(lines) > 0 indent = " " * (len(lines[0]) - len(lines[0].lstrip())) # Handle directive. message = message.strip("\n") if directive is not None: base = f"{indent}.. {directive}::\n" message = message.replace("\n", "\n" + indent + " " * 4) message = base + indent + " " * 4 + message else: message = indent + message.replace("\n", "\n" + indent) # Add two blank lines before/after message, if necessary. if insert_after ^ (pattern == "\n\n"): # Only two blank lines before message if: # 1. Inserting message after pattern and pattern is not two blank lines. # 2. Inserting message before pattern and pattern is two blank lines. message = "\n\n" + message if (not insert_after) ^ (pattern == "\n\n"): # Only two blank lines after message if: # 1. Inserting message before pattern and pattern is not two blank lines. # 2. Inserting message after pattern and pattern is two blank lines. message = message + "\n\n" # Insert message before/after pattern. parts = [head, message, tail] # Build new docstring. obj.__doc__ = "".join(parts) def _consumption_api( if_more_than_read: bool = False, datasource_metadata: Optional[str] = None, extra_condition: Optional[str] = None, delegate: Optional[str] = None, pattern: str = "Examples:", insert_after: bool = False, ) -> Callable[[F], F]: """Annotate the function with an indication that it's a consumption API, and that it will trigger Dataset execution. """ base = ( " will trigger execution of the lazy transformations performed on " "this dataset." ) if delegate: message = delegate + base elif not if_more_than_read: message = "This operation" + base else: condition = "If this dataset consists of more than a read, " if datasource_metadata is not None: condition += ( f"or if the {datasource_metadata} can't be determined from the " "metadata provided by the datasource, " ) if extra_condition is not None: condition += extra_condition + ", " message = condition + "then this operation" + base def wrap(obj: F) -> F: _insert_doc_at_pattern( obj, message=message, pattern=pattern, insert_after=insert_after, directive="note", ) return obj return wrap @overload def ConsumptionAPI(obj: F) -> F: ... @overload def ConsumptionAPI( *, if_more_than_read: bool = False, datasource_metadata: Optional[str] = None, extra_condition: Optional[str] = None, delegate: Optional[str] = None, ) -> Callable[[F], F]: ... def ConsumptionAPI(*args, **kwargs): """Annotate the function with an indication that it's a consumption API, and that it will trigger Dataset execution. """ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): return _consumption_api()(args[0]) return _consumption_api(*args, **kwargs) def _all_to_all_api() -> Callable[[F], F]: """Annotate the function with an indication that it's a all to all API, and that it is an operation that requires all inputs to be materialized in-memory to execute. """ def wrap(obj: F) -> F: _insert_doc_at_pattern( obj, message=( "This operation requires all inputs to be " "materialized in object store for it to execute." ), pattern="Examples:", insert_after=False, directive="note", ) return obj return wrap @overload def AllToAllAPI(obj: F) -> F: ... def AllToAllAPI(*args, **kwargs): """Annotate the function with an indication that it's a all to all API, and that it is an operation that requires all inputs to be materialized in-memory to execute. """ # This should only be used as a decorator for dataset methods. assert len(args) == 1 and len(kwargs) == 0 and callable(args[0]) return _all_to_all_api()(args[0]) def get_compute_strategy( fn: "UserDefinedFunction", fn_constructor_args: Optional[Iterable[Any]] = None, compute: Optional[Union[str, "ComputeStrategy"]] = None, concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None, ) -> "ComputeStrategy": """Get `ComputeStrategy` based on the function or class, and concurrency information. Args: fn: The function or generator to apply to a record batch, or a class type that can be instantiated to create such a callable. fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. compute: Either "tasks" (default) to use Ray Tasks or an :class:`~ray.data.ActorPoolStrategy` to use an autoscaling actor pool. concurrency: The number of Ray workers to use concurrently. Returns: The `ComputeStrategy` for execution. """ # Lazily import these objects to avoid circular imports. from ray.data._internal.compute import ActorPoolStrategy, TaskPoolStrategy from ray.data.block import CallableClass if isinstance(fn, CallableClass): is_callable_class = True else: # TODO(chengsu): disallow object that is not a function. For example, # An object instance of class often indicates a bug in user code. is_callable_class = False if fn_constructor_args is not None: raise ValueError( "``fn_constructor_args`` can only be specified if providing a " f"callable class instance for ``fn``, but got: {fn}." ) if compute is not None: if is_callable_class and ( compute == "tasks" or isinstance(compute, TaskPoolStrategy) ): raise ValueError( f"You specified the callable class {fn} as your UDF with the compute " f"{compute}, but Ray Data can't schedule callable classes with the task " f"pool strategy. To fix this error, pass an ActorPoolStrategy to compute or " f"None to use the default compute strategy." ) elif not is_callable_class and ( compute == "actors" or isinstance(compute, ActorPoolStrategy) ): raise ValueError( f"You specified the function {fn} as your UDF with the compute " f"{compute}, but Ray Data can't schedule regular functions with the actor " f"pool strategy. To fix this error, pass a TaskPoolStrategy to compute or " f"None to use the default compute strategy." ) return compute elif concurrency is not None: # Legacy code path to support `concurrency` argument. logger.warning( "The argument ``concurrency`` is deprecated in Ray 2.51. Please specify " "argument ``compute`` instead. For more information, see " "https://docs.ray.io/en/master/data/transforming-data.html#" "stateful-transforms." ) if isinstance(concurrency, tuple): # Validate tuple length and that all elements are integers if len(concurrency) not in (2, 3) or not all( isinstance(c, int) for c in concurrency ): raise ValueError( "``concurrency`` is expected to be set as a tuple of " f"integers, but got: {concurrency}." ) # Check if function is callable class (common validation) if not is_callable_class: raise ValueError( "``concurrency`` is set as a tuple of integers, but ``fn`` " f"is not a callable class: {fn}. Use ``concurrency=n`` to " "control maximum number of workers to use." ) # Create ActorPoolStrategy based on tuple length if len(concurrency) == 2: return ActorPoolStrategy( min_size=concurrency[0], max_size=concurrency[1] ) else: # len(concurrency) == 3 return ActorPoolStrategy( min_size=concurrency[0], max_size=concurrency[1], initial_size=concurrency[2], ) elif isinstance(concurrency, int): if is_callable_class: return ActorPoolStrategy(size=concurrency) else: return TaskPoolStrategy(size=concurrency) else: raise ValueError( "``concurrency`` is expected to be set as an integer or a " f"tuple of integers, but got: {concurrency}." ) else: if is_callable_class: return ActorPoolStrategy(min_size=1, max_size=None) else: return TaskPoolStrategy() def capfirst(s: str): """Capitalize the first letter of a string Args: s: String to capitalize Returns: Capitalized string """ return s[0].upper() + s[1:] def capitalize(s: str): """Capitalize a string, removing '_' and keeping camelcase. Args: s: String to capitalize Returns: Capitalized string with no underscores. """ return "".join(capfirst(x) for x in s.split("_")) def pandas_df_to_arrow_block( df: "pandas.DataFrame", ) -> Tuple["Block", "BlockMetadataWithSchema"]: from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadataWithSchema block = BlockAccessor.for_block(df).to_arrow() stats = BlockExecStats.builder() return block, BlockMetadataWithSchema.from_block(block, stats=stats.build()) def ndarray_to_block( ndarray: np.ndarray, ctx: DataContext ) -> Tuple["Block", "BlockMetadataWithSchema"]: from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadataWithSchema DataContext._set_current(ctx) stats = BlockExecStats.builder() block = BlockAccessor.batch_to_block({"data": ndarray}) return block, BlockMetadataWithSchema.from_block(block, stats=stats.build()) def get_table_block_metadata_schema( table: Union["pyarrow.Table", "pandas.DataFrame"], ) -> "BlockMetadataWithSchema": from ray.data.block import BlockExecStats, BlockMetadataWithSchema stats = BlockExecStats.builder() return BlockMetadataWithSchema.from_block(table, stats=stats.build()) def unify_block_metadata_schema( block_metadata_with_schemas: List["BlockMetadataWithSchema"], ) -> Optional["Schema"]: """For the input list of BlockMetadata, return a unified schema of the corresponding blocks. If the metadata have no valid schema, returns None. Args: block_metadata_with_schemas: List of BlockMetadata to unify Returns: A unified schema of the input list of schemas, or None if no valid schemas are provided. """ # Some blocks could be empty, in which case we cannot get their schema. # TODO(ekl) validate schema is the same across different blocks. # First check if there are blocks with computed schemas, then unify # valid schemas from all such blocks. schemas_to_unify = [] for m in block_metadata_with_schemas: if m.schema is not None and (m.num_rows is None or m.num_rows > 0): schemas_to_unify.append(m.schema) return unify_schemas_with_validation(schemas_to_unify) def unify_schemas_with_validation( schemas_to_unify: Iterable["Schema"], ) -> Optional["Schema"]: if schemas_to_unify: from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas # Check valid pyarrow installation before attempting schema unification try: import pyarrow as pa except ImportError: pa = None # If the result contains PyArrow schemas, unify them if pa is not None and all(isinstance(s, pa.Schema) for s in schemas_to_unify): return unify_schemas(schemas_to_unify, promote_types=True) # Otherwise, if the resulting schemas are simple types (e.g. int), # return the first schema. return schemas_to_unify[0] return None def unify_ref_bundles_schema( ref_bundles: List["RefBundle"], ) -> Optional["Schema"]: schemas_to_unify = [] for bundle in ref_bundles: if bundle.schema is not None and ( bundle.num_rows() is None or bundle.num_rows() > 0 ): schemas_to_unify.append(bundle.schema) return unify_schemas_with_validation(schemas_to_unify) def find_partition_index( table: Union["pyarrow.Table", "pandas.DataFrame"], desired: Tuple[Union[int, float]], sort_key: "SortKey", ) -> int: """For the given block, find the index where the desired value should be added, to maintain sorted order. We do this by iterating over each column, starting with the primary sort key, and binary searching for the desired value in the column. Each binary search shortens the "range" of indices (represented by ``left`` and ``right``, which are indices of rows) where the desired value could be inserted. Args: table: The block to search in. desired: A single tuple representing the boundary to partition at. ``len(desired)`` must be less than or equal to the number of columns being sorted. sort_key: The sort key to use for sorting, providing the columns to be sorted and their directions. Returns: The index where the desired value should be inserted to maintain sorted order. """ columns = sort_key.get_columns() descending = sort_key.get_descending() left, right = 0, len(table) for i in range(len(desired)): if left == right: return right col_name = columns[i] col_vals = table[col_name].to_numpy()[left:right] desired_val = desired[i] # Handle null values - replace them with sentinel values if desired_val is None: desired_val = NULL_SENTINEL prevleft = left if descending[i] is True: # ``np.searchsorted`` expects the array to be sorted in ascending # order, so we pass ``sorter``, which is an array of integer indices # that sort ``col_vals`` into ascending order. The returned index # is an index into the ascending order of ``col_vals``, so we need # to subtract it from ``len(col_vals)`` to get the index in the # original descending order of ``col_vals``. sorter = np.arange(len(col_vals) - 1, -1, -1) left = prevleft + ( len(col_vals) - np.searchsorted( col_vals, desired_val, side="right", sorter=sorter, ) ) right = prevleft + ( len(col_vals) - np.searchsorted( col_vals, desired_val, side="left", sorter=sorter, ) ) else: left = prevleft + np.searchsorted(col_vals, desired_val, side="left") right = prevleft + np.searchsorted(col_vals, desired_val, side="right") return right if descending[0] is True else left def get_attribute_from_class_name(class_name: str) -> Any: """Get Python attribute from the provided class name. The caller needs to make sure the provided class name includes full module name, and can be imported successfully. """ from importlib import import_module paths = class_name.split(".") if len(paths) < 2: raise ValueError(f"Cannot create object from {class_name}.") module_name = ".".join(paths[:-1]) attribute_name = paths[-1] return getattr(import_module(module_name), attribute_name) T = TypeVar("T") U = TypeVar("U") class _InterruptibleQueue(Queue): """Extension of Python's `queue.Queue` providing ability to get interrupt its method callers in other threads""" INTERRUPTION_CHECK_FREQUENCY_SEC = 0.5 def __init__( self, max_size: int, interrupted_event: Optional[threading.Event] = None ): super().__init__(maxsize=max_size) self._interrupted_event = interrupted_event or threading.Event() def get(self, block=True, timeout=None): if not block or timeout is not None: return super().get(block, timeout) # In case when the call is blocking and no timeout is specified (ie blocking # indefinitely) we apply the following protocol to make it interruptible: # # 1. `Queue.get` is invoked w/ 500ms timeout # 2. `Empty` exception is intercepted (will be raised upon timeout elapsing) # 3. If interrupted flag is set `InterruptedError` is raised # 4. Otherwise, protocol retried (until interrupted or queue # becoming non-empty) while True: if self._interrupted_event.is_set(): raise InterruptedError() try: return super().get( block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC ) except Empty: pass def put(self, item, block=True, timeout=None): if not block or timeout is not None: super().put(item, block, timeout) return # In case when the call is blocking and no timeout is specified (ie blocking # indefinitely) we apply the following protocol to make it interruptible: # # 1. `Queue.pet` is invoked w/ 500ms timeout # 2. `Full` exception is intercepted (will be raised upon timeout elapsing) # 3. If interrupted flag is set `InterruptedError` is raised # 4. Otherwise, protocol retried (until interrupted or queue # becomes non-full) while True: if self._interrupted_event.is_set(): raise InterruptedError() try: super().put( item, block=True, timeout=self.INTERRUPTION_CHECK_FREQUENCY_SEC ) return except Full: pass def make_async_gen( base_iterator: Iterator[T], fn: Callable[[Iterator[T]], Iterator[U]], preserve_ordering: bool, num_workers: int = 1, buffer_size: int = 1, ) -> Generator[U, None, None]: """Returns a generator (iterator) mapping items from the provided iterator applying provided transformation in parallel (using a thread-pool). NOTE: There are some important constraints that needs to be carefully understood before using this method 1. If `preserve_ordering` is True a. This method would unroll input iterator eagerly (irrespective of the speed of resulting generator being consumed). This is necessary as we can not guarantee liveness of the algorithm AND preserving of the original ordering at the same time. b. Resulting ordering of the output will "match" ordering of the input, ie that: iterator = [A1, A2, ... An] output iterator = [map(A1), map(A2), ..., map(An)] 2. If `preserve_ordering` is False a. No more than `num_workers * (queue_buffer_size + 1)` elements will be fetched from the iterator b. Resulting ordering of the output is unspecified (and is non-deterministic) Args: base_iterator: Iterator yielding elements to map fn: Transformation to apply to each element preserve_ordering: Whether ordering has to be preserved num_workers: The number of threads to use in the threadpool (defaults to 1) buffer_size: Number of objects to be buffered in its input/output queues (per queue; defaults to 2). Total number of objects held in memory could be calculated as: num_workers * buffer_size * 2 (input and output) Returns: An generator (iterator) of the elements corresponding to the source elements mapped by provided transformation (while *preserving the ordering*) """ gen_id = random.randint(0, 2**31 - 1) if num_workers < 1: raise ValueError("Size of threadpool must be at least 1.") # Signal handler used to interrupt workers when terminating interrupted_event = threading.Event() # To apply transformations to elements in parallel *and* preserve the ordering # following invariants are established: # - Every worker is handled by standalone thread # - Every worker is assigned an input and an output queue # # And following protocol is implemented: # - Filling worker traverses input iterator round-robin'ing elements across # the input queues (in order!) # - Transforming workers traverse respective input queue in-order: de-queueing # element, applying transformation and enqueuing the result into the output # queue # - Generator (returned from this method) traverses output queues (in the same # order as input queues) dequeues 1 mapped element at a time from each output # queue and yields it # # However, in case when we're preserving the ordering we can not enforce the input # queue size as this could result in deadlocks since transformations could be # producing sequences of arbitrary length. # # Check `test_make_async_gen_varying_seq_length_stress_test` for more context on # this problem. if preserve_ordering: input_queue_buf_size = -1 num_input_queues = num_workers else: input_queue_buf_size = (buffer_size + 1) * num_workers num_input_queues = 1 input_queues = [ _InterruptibleQueue(input_queue_buf_size, interrupted_event) for _ in range(num_input_queues) ] output_queues = [ _InterruptibleQueue(buffer_size, interrupted_event) for _ in range(num_workers) ] # Filling worker def _run_filling_worker(): try: # First, round-robin elements from the iterator into # corresponding input queues (one by one) for idx, item in enumerate(base_iterator): input_queues[idx % num_input_queues].put(item) # NOTE: We have to Enqueue sentinel objects for every transforming # worker: # - In case of preserving order of ``num_queues`` == ``num_workers`` # we will enqueue 1 sentinel per queue # - In case of NOT preserving order all ``num_workers`` sentinels # will be enqueued into a single queue for idx in range(num_workers): input_queues[idx % num_input_queues].put(SENTINEL) except InterruptedError: pass except Exception as e: logger.warning("Caught exception in filling worker!", exc_info=e) # In case of filling worker encountering an exception we have to propagate # it back to the (main) iterating thread. To achieve that we're traversing # output queues *backwards* relative to the order of iterator-thread such # that they are more likely to meet w/in a single iteration. for output_queue in reversed(output_queues): output_queue.put(e) # Transforming worker def _run_transforming_worker(input_queue, output_queue): try: # Create iterator draining the queue, until it receives sentinel # # NOTE: `queue.get` is blocking! input_queue_iter = iter(input_queue.get, SENTINEL) for result in fn(input_queue_iter): # Enqueue result of the transformation output_queue.put(result) # Enqueue sentinel (to signal that transformations are completed) output_queue.put(SENTINEL) except InterruptedError: pass except Exception as e: logger.warning("Caught exception in transforming worker!", exc_info=e) # NOTE: In this case we simply enqueue the exception rather than # interrupting output_queue.put(e) # Start workers threads filling_worker_thread = threading.Thread( target=_run_filling_worker, name=f"map_tp_filling_worker-{gen_id}", daemon=True, ) filling_worker_thread.start() transforming_worker_threads = [ threading.Thread( target=_run_transforming_worker, name=f"map_tp_transforming_worker-{gen_id}-{idx}", args=(input_queues[idx % num_input_queues], output_queues[idx]), daemon=True, ) for idx in range(num_workers) ] for t in transforming_worker_threads: t.start() # Use main thread to yield output batches try: # Keep track of remaining non-empty output queues remaining_output_queues = output_queues while len(remaining_output_queues) > 0: # To provide deterministic ordering of the produced iterator we rely # on the following invariants: # # - Elements from the original iterator are round-robin'd into # input queues (in order) # - Individual workers drain their respective input queues populating # output queues with the results of applying transformation to the # original item (and hence preserving original ordering of the input # queue) # - To yield from the generator output queues are traversed in the same # order and one single element is dequeued (in a blocking way!) at a # time from every individual output queue # empty_queues = [] # At every iteration only remaining non-empty queues # are traversed (to prevent blocking on exhausted queue) for output_queue in remaining_output_queues: # NOTE: This is blocking! item = output_queue.get() if isinstance(item, Exception): raise item if item is SENTINEL: empty_queues.append(output_queue) else: yield item if empty_queues: remaining_output_queues = [ q for q in remaining_output_queues if q not in empty_queues ] finally: # Set flag to interrupt workers (to make sure no dangling # threads holding the objects are left behind) # # NOTE: Interrupted event is set to interrupt the running threads # that might be blocked otherwise waiting on inputs from respective # queues. However, even though we're interrupting the threads we can't # guarantee that threads will be interrupted in time (as this is # dependent on Python's GC finalizer to close the generator by raising # `GeneratorExit`) and hence we can't join on either filling or # transforming workers. interrupted_event.set() class RetryingContextManager: def __init__( self, f: pyarrow.NativeFile, context: DataContext, max_attempts: int = 10, max_backoff_s: int = 32, ): self._f = f self._data_context = context self._max_attempts = max_attempts self._max_backoff_s = max_backoff_s def __repr__(self): return f"<{self.__class__.__name__} fs={self.handler.unwrap()}>" def _retry_operation(self, operation: Callable, description: str): """Execute an operation with retries.""" return call_with_retry( operation, description=description, match=self._data_context.retried_io_errors, max_attempts=self._max_attempts, max_backoff_s=self._max_backoff_s, ) def __enter__(self): return self._retry_operation(self._f.__enter__, "enter file context") def __exit__(self, exc_type, exc_value, traceback): self._retry_operation( lambda: self._f.__exit__(exc_type, exc_value, traceback), "exit file context", ) class RetryingPyFileSystem(pyarrow.fs.PyFileSystem): def __init__(self, handler: "RetryingPyFileSystemHandler"): if not isinstance(handler, RetryingPyFileSystemHandler): assert ValueError("handler must be a RetryingPyFileSystemHandler") super().__init__(handler) @property def retryable_errors(self) -> List[str]: return self.handler._retryable_errors def unwrap(self): return self.handler.unwrap() @classmethod def wrap( cls, fs: "pyarrow.fs.FileSystem", retryable_errors: List[str], max_attempts: int = 10, max_backoff_s: int = 32, ): if isinstance(fs, RetryingPyFileSystem): return fs handler = RetryingPyFileSystemHandler( fs, retryable_errors, max_attempts, max_backoff_s ) return cls(handler) def __reduce__(self): # Serialization of this class breaks for some reason without this return (self.__class__, (self.handler,)) @classmethod def __setstate__(cls, state): # Serialization of this class breaks for some reason without this return cls(*state) class RetryingPyFileSystemHandler(pyarrow.fs.FileSystemHandler): """Wrapper for filesystem objects that adds retry functionality for file operations. This class wraps any filesystem object and adds automatic retries for common file operations that may fail transiently. """ def __init__( self, fs: "pyarrow.fs.FileSystem", retryable_errors: List[str] = tuple(), max_attempts: int = 10, max_backoff_s: int = 32, ): """Initialize the retrying filesystem wrapper. Args: fs: The underlying filesystem to wrap context: DataContext for retry settings max_attempts: Maximum number of retry attempts max_backoff_s: Maximum backoff time in seconds """ assert not isinstance( fs, RetryingPyFileSystem ), "Cannot wrap a RetryingPyFileSystem" self._fs = fs self._retryable_errors = retryable_errors self._max_attempts = max_attempts self._max_backoff_s = max_backoff_s def _retry_operation(self, operation: Callable, description: str): """Execute an operation with retries.""" return call_with_retry( operation, description=description, match=self._retryable_errors, max_attempts=self._max_attempts, max_backoff_s=self._max_backoff_s, ) def unwrap(self): return self._fs def copy_file(self, src: str, dest: str): """Copy a file.""" return self._retry_operation( lambda: self._fs.copy_file(src, dest), f"copy file from {src} to {dest}" ) def create_dir(self, path: str, recursive: bool): """Create a directory and subdirectories.""" return self._retry_operation( lambda: self._fs.create_dir(path, recursive=recursive), f"create directory {path}", ) def delete_dir(self, path: str): """Delete a directory and its contents, recursively.""" return self._retry_operation( lambda: self._fs.delete_dir(path), f"delete directory {path}" ) def delete_dir_contents(self, path: str, missing_dir_ok: bool = False): """Delete a directory's contents, recursively.""" return self._retry_operation( lambda: self._fs.delete_dir_contents(path, missing_dir_ok=missing_dir_ok), f"delete directory contents {path}", ) def delete_file(self, path: str): """Delete a file.""" return self._retry_operation( lambda: self._fs.delete_file(path), f"delete file {path}" ) def delete_root_dir_contents(self): return self._retry_operation( lambda: self._fs.delete_dir_contents("/", accept_root_dir=True), "delete root dir contents", ) def equals(self, other: "pyarrow.fs.FileSystem") -> bool: """Test if this filesystem equals another.""" return self._fs.equals(other) def get_file_info(self, paths: List[str]): """Get info for the given files.""" return self._retry_operation( lambda: self._fs.get_file_info(paths), f"get file info for {paths}", ) def get_file_info_selector(self, selector): return self._retry_operation( lambda: self._fs.get_file_info(selector), f"get file info for {selector}", ) def get_type_name(self): return "RetryingPyFileSystem" def move(self, src: str, dest: str): """Move / rename a file or directory.""" return self._retry_operation( lambda: self._fs.move(src, dest), f"move from {src} to {dest}" ) def normalize_path(self, path: str) -> str: """Normalize filesystem path.""" return self._retry_operation( lambda: self._fs.normalize_path(path), f"normalize path {path}" ) def open_append_stream( self, path: str, metadata=None, ) -> "pyarrow.NativeFile": """Open an output stream for appending. Compression is disabled in this method because it is handled in the PyFileSystem abstract class. """ return self._retry_operation( lambda: self._fs.open_append_stream( path, compression=None, metadata=metadata, ), f"open append stream for {path}", ) def open_input_stream( self, path: str, ) -> "pyarrow.NativeFile": """Open an input stream for sequential reading. Compression is disabled in this method because it is handled in the PyFileSystem abstract class. """ return self._retry_operation( lambda: self._fs.open_input_stream(path, compression=None), f"open input stream for {path}", ) def open_output_stream( self, path: str, metadata=None, ) -> "pyarrow.NativeFile": """Open an output stream for sequential writing." Compression is disabled in this method because it is handled in the PyFileSystem abstract class. """ return self._retry_operation( lambda: self._fs.open_output_stream( path, compression=None, metadata=metadata, ), f"open output stream for {path}", ) def open_input_file(self, path: str) -> "pyarrow.NativeFile": """Open an input file for random access reading.""" return self._retry_operation( lambda: self._fs.open_input_file(path), f"open input file {path}" ) def iterate_with_retry( iterable_factory: Callable[[], Iterable], description: str, *, match: Optional[List[str]] = None, max_attempts: int = 10, max_backoff_s: int = 32, ) -> Any: """Iterate through an iterable with retries. If the iterable raises an exception, this function recreates and re-iterates through the iterable, while skipping the items that have already been yielded. Args: iterable_factory: A no-argument function that creates the iterable. match: A list of strings to match in the exception message. If ``None``, any error is retried. description: An imperitive description of the function being retried. For example, "open the file". max_attempts: The maximum number of attempts to retry. max_backoff_s: The maximum number of seconds to backoff. """ assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}." num_items_yielded = 0 for attempt in range(max_attempts): try: iterable = iterable_factory() for item_index, item in enumerate(iterable): if item_index < num_items_yielded: # Skip items that have already been yielded. continue num_items_yielded += 1 yield item return except Exception as e: is_retryable = match is None or any(pattern in str(e) for pattern in match) if is_retryable and attempt + 1 < max_attempts: # Retry with binary expoential backoff with random jitter. backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random() logger.debug( f"Retrying {attempt+1} attempts to {description} " f"after {backoff} seconds." ) time.sleep(backoff) else: raise e from None def convert_bytes_to_human_readable_str(num_bytes: int) -> str: if num_bytes >= 1e9: num_bytes_str = f"{round(num_bytes / 1e9)}GB" elif num_bytes >= 1e6: num_bytes_str = f"{round(num_bytes / 1e6)}MB" else: num_bytes_str = f"{round(num_bytes / 1e3)}KB" return num_bytes_str def _validate_rows_per_file_args( *, num_rows_per_file: Optional[int] = None, min_rows_per_file: Optional[int] = None, max_rows_per_file: Optional[int] = None, ) -> Tuple[Optional[int], Optional[int]]: """Helper method to validate and handle rows per file arguments. Args: num_rows_per_file: Deprecated parameter for number of rows per file min_rows_per_file: New parameter for minimum rows per file max_rows_per_file: New parameter for maximum rows per file Returns: A tuple of (effective_min_rows_per_file, effective_max_rows_per_file) """ if num_rows_per_file is not None: import warnings warnings.warn( "`num_rows_per_file` is deprecated and will be removed in a future release. " "Use `min_rows_per_file` instead.", DeprecationWarning, stacklevel=3, ) if min_rows_per_file is not None: raise ValueError( "Cannot specify both `num_rows_per_file` and `min_rows_per_file`. " "Use `min_rows_per_file` as `num_rows_per_file` is deprecated." ) min_rows_per_file = num_rows_per_file # Validate max_rows_per_file if max_rows_per_file is not None and max_rows_per_file <= 0: raise ValueError("max_rows_per_file must be a positive integer") # Validate min_rows_per_file if min_rows_per_file is not None and min_rows_per_file <= 0: raise ValueError("min_rows_per_file must be a positive integer") # Validate that max >= min if both are specified if ( min_rows_per_file is not None and max_rows_per_file is not None and min_rows_per_file > max_rows_per_file ): raise ValueError( f"min_rows_per_file ({min_rows_per_file}) cannot be greater than " f"max_rows_per_file ({max_rows_per_file})" ) return min_rows_per_file, max_rows_per_file def is_nan(value) -> bool: """Returns true if provide value is ``np.nan``""" try: return isinstance(value, float) and np.isnan(value) except TypeError: return False def is_null(value: Any) -> bool: """This generalization of ``is_nan`` util qualifying both None and np.nan as null values""" return value is None or is_nan(value) def keys_equal(keys1, keys2): if len(keys1) != len(keys2): return False for k1, k2 in zip(keys1, keys2): if not ((is_nan(k1) and is_nan(k2)) or k1 == k2): return False return True def get_total_obj_store_mem_on_node() -> int: """Return the total object store memory on the current node. This function incurs an RPC. Use it cautiously. """ node_id = ray.get_runtime_context().get_node_id() total_resources_per_node = ray._private.state.total_resources_per_node() assert ( node_id in total_resources_per_node ), f"Expected node '{node_id}' to be in resources: {total_resources_per_node}" return total_resources_per_node[node_id]["object_store_memory"] class MemoryProfiler: """A context manager that polls the USS of the current process. This class approximates the max USS by polling memory and subtracting the amount of shared memory from the resident set size (RSS). It's not a perfect estimate (it can underestimate, e.g., if you use Torch tensors), but estimating the USS is much cheaper than computing the actual USS. .. warning:: This class only works with Linux. If you use it on another platform, `estimate_max_uss` always returns ``None``. Example: .. testcode:: with MemoryProfiler(poll_interval_s=1.0) as profiler: for i in range(10): ... # Your code here print(f"Max USS: {profiler.estimate_max_uss()}") profiler.reset() """ def __init__(self, poll_interval_s: Optional[float]): """ Args: poll_interval_s: The interval to poll the USS of the process. If `None`, this class won't poll the USS. """ self._poll_interval_s = poll_interval_s self._process = psutil.Process(os.getpid()) self._max_uss = None self._max_uss_lock = threading.Lock() self._uss_poll_thread = None self._stop_uss_poll_event = None def __repr__(self): return f"MemoryProfiler(poll_interval_s={self._poll_interval_s})" def __enter__(self): if self._can_estimate_uss() and self._poll_interval_s is not None: ( self._uss_poll_thread, self._stop_uss_poll_event, ) = self._start_uss_poll_thread() return self def __exit__(self, exc_type, exc_val, exc_tb): if self._uss_poll_thread is not None: self._stop_uss_poll_thread() def estimate_max_uss(self) -> Optional[int]: """Get an estimate of the max USS of the current process. Returns: An estimate of the max USS of the process in bytes, or ``None`` if an estimate isn't available. """ if not self._can_estimate_uss(): assert self._max_uss is None return None with self._max_uss_lock: if self._max_uss is None: self._max_uss = self._estimate_uss() else: self._max_uss = max(self._max_uss, self._estimate_uss()) assert self._max_uss is not None return self._max_uss def reset(self): with self._max_uss_lock: self._max_uss = None def _start_uss_poll_thread(self) -> Tuple[threading.Thread, threading.Event]: assert self._poll_interval_s is not None assert self._can_estimate_uss() stop_event = threading.Event() def poll_uss(): while not stop_event.is_set(): with self._max_uss_lock: if self._max_uss is None: self._max_uss = self._estimate_uss() else: self._max_uss = max(self._max_uss, self._estimate_uss()) stop_event.wait(self._poll_interval_s) thread = threading.Thread(target=poll_uss, daemon=True) thread.start() return thread, stop_event def _stop_uss_poll_thread(self): if self._stop_uss_poll_event is not None: self._stop_uss_poll_event.set() self._uss_poll_thread.join() def _estimate_uss(self) -> int: assert self._can_estimate_uss() memory_info = self._process.memory_info() # Estimate the USS (the amount of memory that'd be free if we killed the # process right now) as the difference between the RSS (total physical memory) # and amount of shared physical memory. return memory_info.rss - memory_info.shared @staticmethod @functools.cache def _can_estimate_uss() -> bool: # MacOS and Windows don't have the 'shared' attribute of `memory_info()`. return platform.system() == "Linux" def unzip(data: List[Tuple[Any, ...]]) -> Tuple[List[Any], ...]: """Unzips a list of tuples into a tuple of lists Args: data: A list of tuples to unzip. Returns: A tuple of lists, where each list corresponds to one element of the tuples in the input list. """ return tuple(map(list, zip(*data))) def _sort_df(df: pd.DataFrame) -> pd.DataFrame: """Sort DataFrame by columns and rows, and also handle unhashable types.""" df = df.copy() def to_sortable(x): if isinstance(x, (list, np.ndarray)): return tuple(to_sortable(i) for i in x) if isinstance(x, dict): return tuple(sorted((k, to_sortable(v)) for k, v in x.items())) return x sort_cols = [] temp_cols = [] # Sort by all columns to ensure deterministic order. columns = sorted(df.columns) for col in columns: if df[col].dtype == "object": # Create a temporary column for sorting to handle unhashable types. # Use UUID to avoid collisions with existing column names. temp_col = f"__sort_proxy_{uuid.uuid4().hex}_{col}__" df[temp_col] = df[col].map(to_sortable) sort_cols.append(temp_col) temp_cols.append(temp_col) else: sort_cols.append(col) sorted_df = df.sort_values(sort_cols) if temp_cols: sorted_df = sorted_df.drop(columns=temp_cols) return sorted_df def rows_same(actual: pd.DataFrame, expected: pd.DataFrame) -> bool: """Check if two DataFrames have the same rows. Unlike the built-in pandas equals method, this function ignores indices and the order of rows. This is useful for testing Ray Data because its interface doesn't usually guarantee the order of rows. """ if len(actual) != len(expected): return False if len(actual) == 0: return True pd.testing.assert_frame_equal( _sort_df(actual).reset_index(drop=True), _sort_df(expected).reset_index(drop=True), check_dtype=False, ) return True def merge_resources_to_ray_remote_args( num_cpus: Optional[int], num_gpus: Optional[int], memory: Optional[int], ray_remote_args: Dict[str, Any], ) -> Dict[str, Any]: """Convert the given resources to Ray remote args. Args: num_cpus: The number of CPUs to be added to the Ray remote args. num_gpus: The number of GPUs to be added to the Ray remote args. memory: The memory to be added to the Ray remote args. ray_remote_args: The Ray remote args to be merged. Returns: The converted arguments. """ ray_remote_args = ray_remote_args.copy() if num_cpus is not None: ray_remote_args["num_cpus"] = num_cpus if num_gpus is not None: ray_remote_args["num_gpus"] = num_gpus if memory is not None: ray_remote_args["memory"] = memory return ray_remote_args @DeveloperAPI def infer_compression(path: str) -> Optional[str]: import pyarrow as pa compression = None try: # Try to detect compression codec from path. compression = pa.Codec.detect(path).name except (ValueError, TypeError): # Arrow's compression inference on the file path doesn't work for Snappy, so we double-check ourselves. import pathlib suffix = pathlib.Path(path).suffix if suffix and suffix[1:] == "snappy": compression = "snappy" return compression