| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- import abc
- import functools
- import inspect
- import logging
- import os
- import socket
- from typing import (
- Any,
- Callable,
- ContextManager,
- Dict,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
- )
- import ray
- from ray._common.network_utils import find_free_port, is_ipv6
- from ray.actor import ActorHandle
- from ray.air._internal.util import (
- StartTraceback,
- StartTracebackWithWorkerRank,
- )
- from ray.exceptions import RayActorError
- from ray.types import ObjectRef
# Logger for this module; named after the module itself.
logger = logging.getLogger(__name__)
# Generic return type of user training functions (see ``construct_train_func``).
T = TypeVar("T")
def check_for_failure(
    remote_values: List[ObjectRef],
) -> Tuple[bool, Optional[Exception]]:
    """Resolve remote values one by one and report whether a worker actor died.

    The position of each reference in ``remote_values`` is reported as the
    rank of the worker that produced it.

    Args:
        remote_values: Object references returned by Ray actor method calls.

    Returns:
        ``(True, None)`` once every reference has resolved successfully, or
        ``(False, exc)`` as soon as one of them raises ``RayActorError``
        (the failing rank is logged at INFO level).

    Raises:
        StartTracebackWithWorkerRank: If resolving a reference raises any
            other exception; the original error is chained as its cause.
    """
    pending = remote_values.copy()
    while pending:
        # A failed task is also reported as "ready" by ray.wait; the failure
        # only surfaces when the reference is fetched with ray.get.
        ready, pending = ray.wait(pending)
        for ref in ready:
            try:
                ray.get(ref)
            except RayActorError as actor_error:
                rank = remote_values.index(ref)
                logger.info(f"Worker {rank} has failed.")
                return False, actor_error
            except Exception as user_error:
                # Anything other than an actor death (e.g. an error raised by
                # the training code) propagates, tagged with the worker rank.
                raise StartTracebackWithWorkerRank(
                    worker_rank=remote_values.index(ref)
                ) from user_error
    return True, None
def get_address_and_port() -> Tuple[str, int]:
    """Return this node's IP address together with a currently free port.

    The free port is searched for in the socket address family that matches
    the node IP (IPv6 if ``is_ipv6`` reports so, otherwise IPv4).
    """
    node_ip = ray.util.get_node_ip_address()
    family = socket.AF_INET6 if is_ipv6(node_ip) else socket.AF_INET
    return node_ip, find_free_port(family)
def update_env_vars(env_vars: Dict[str, Any]):
    """Set environment variables in the current worker process.

    Every value is converted with ``str`` before being written to
    ``os.environ``, so non-string values (ints, ``None``, ...) are accepted.

    Args:
        env_vars: Mapping of variable name to the value to set.
    """
    # Stringify everything first, then write, mirroring a single bulk update.
    as_strings = dict((name, str(value)) for name, value in env_vars.items())
    for name, value in as_strings.items():
        os.environ[name] = value
def count_required_parameters(fn: Callable) -> int:
    """Counts the number of required positional parameters of a function.

    A parameter counts when it can be passed positionally (positional-only,
    positional-or-keyword, or ``*args``) and has no default value. Keyword-only
    parameters and ``**kwargs`` are never counted.

    NOTE: *args counts as 1 required parameter.

    Args:
        fn: The callable to inspect. Anything ``inspect.signature`` accepts
            works, including ``functools.partial`` objects.

    Returns:
        The number of required positional parameters.

    Raises:
        ValueError, TypeError: Propagated from ``inspect.signature`` when no
            signature can be determined for ``fn``.

    Examples
    --------
    >>> def fn(a, b, /, c, *args, d=1, e=2, **kwargs):
    ...     pass
    >>> count_required_parameters(fn)
    4
    >>> fn = lambda: 1
    >>> count_required_parameters(fn)
    0
    >>> def fn(config, a, b=1, c=2):
    ...     pass
    >>> from functools import partial
    >>> count_required_parameters(partial(fn, a=0))
    1
    """
    positional_param_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    }
    # Test the ``empty`` sentinel by identity. ``==`` would invoke the default
    # value's own ``__eq__``, which may raise or return a non-bool (e.g. a
    # NumPy array default), breaking the boolean ``and`` below.
    return sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.default is inspect.Parameter.empty and p.kind in positional_param_kinds
    )
def construct_train_func(
    train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
    config: Optional[Dict[str, Any]],
    train_func_context: ContextManager,
    fn_arg_name: Optional[str] = "train_func",
    discard_returns: bool = False,
) -> Callable[[], T]:
    """Validate ``train_func`` and wrap it into a zero-argument callable.

    The returned callable enters ``train_func_context()`` around the call to
    ``train_func`` and re-raises any exception as ``StartTraceback`` chained
    to the original error.

    Args:
        train_func: The training function to execute. It must take either no
            required arguments or exactly one (the ``config`` dict).
        config: Configuration passed to ``train_func`` when it takes one
            required argument. ``None`` is replaced by an empty dict.
        train_func_context: Called with no arguments to obtain the context
            manager that wraps each execution of ``train_func`` (backend-specific
            logic before and after training).
        fn_arg_name: Name of the training function used in error messages.
        discard_returns: If True, the return value of ``train_func`` is dropped
            and the resulting callable returns ``None``.

    Returns:
        A zero-argument callable that runs the training function.

    Raises:
        ValueError: If ``train_func`` has more than one required argument.
    """
    required = count_required_parameters(train_func)
    if required > 1:
        raise ValueError(
            f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
            f"{required} required arguments instead."
        )

    # Only a single-argument train_func receives the config dict.
    call_args = ({} if config is None else config,) if required == 1 else ()

    if discard_returns:
        # Drop return values so BackendExecutor never tries to deserialize
        # them; they are inaccessible through AIR anyway.
        @functools.wraps(train_func)
        def inner(*args, **kwargs):
            try:
                train_func(*args, **kwargs)
            except Exception as e:
                raise StartTraceback from e

    else:
        inner = train_func

    @functools.wraps(inner)
    def train_fn():
        try:
            with train_func_context():
                return inner(*call_args)
        except Exception as e:
            raise StartTraceback from e

    return train_fn
class Singleton(abc.ABCMeta):
    """Metaclass that turns every class using it into a singleton.

    The first instantiation of a class creates and caches its instance. Later
    instantiations return the cached object without running the constructor
    again, so their arguments are ignored. Deriving from ``abc.ABCMeta`` lets
    it be combined with abstract base classes.

    Based on:
    https://stackoverflow.com/questions/33364070/implementing
    -singleton-as-metaclass-but-for-abstract-classes
    """

    # Cached instance per class; one dict shared by all classes using this
    # metaclass, keyed by the class object.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            pass
        instance = super().__call__(*args, **kwargs)
        cls._instances[cls] = instance
        return instance
class ActorWrapper:
    """Wraps an actor to provide same API as using the base class directly.

    Attribute access on the wrapper is forwarded to the wrapped actor handle:
    ``wrapper.method(*args, **kwargs)`` calls ``actor.method.remote(...)`` and
    blocks on ``ray.get`` for the result.
    """

    def __init__(self, actor: ActorHandle):
        self.actor = actor

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If ``actor`` itself is
        # missing (an instance created via ``__new__`` without ``__init__``, as
        # copy/pickle do before probing ``__setstate__``), reading ``self.actor``
        # below would re-enter __getattr__ until RecursionError. Raise a plain
        # AttributeError so hasattr()/getattr(default) behave normally.
        if item == "actor":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        # The below will fail if trying to access an attribute (not a method)
        # from the actor.
        actor_method = getattr(self.actor, item)
        return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))
|