| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- import contextlib
- import functools
- import logging
- import time
- import traceback
- from datetime import datetime
- from typing import (
- Any,
- Callable,
- ContextManager,
- Dict,
- Generator,
- Generic,
- List,
- Optional,
- TypeVar,
- Union,
- )
- import ray
- from ray.train._internal.utils import count_required_parameters
- from ray.train.v2._internal.exceptions import UserExceptionWithTraceback
- from ray.types import ObjectRef
- logger = logging.getLogger(__name__)
- T = TypeVar("T")
- def bundle_to_remote_args(bundle: dict) -> dict:
- """Convert a bundle of resources to Ray actor/task arguments.
- >>> bundle_to_remote_args({"GPU": 1, "memory": 1, "custom": 0.1})
- {'num_cpus': 0, 'num_gpus': 1, 'memory': 1, 'resources': {'custom': 0.1}}
- """
- bundle = bundle.copy()
- args = {
- "num_cpus": bundle.pop("CPU", 0),
- "num_gpus": bundle.pop("GPU", 0),
- "memory": bundle.pop("memory", 0),
- }
- if bundle:
- args["resources"] = bundle
- return args
- def construct_train_func(
- train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
- config: Optional[Dict[str, Any]],
- train_func_context: ContextManager,
- fn_arg_name: Optional[str] = "train_loop_per_worker",
- ) -> Callable[[], T]:
- """Validates and constructs the training function to execute.
- Args:
- train_func: The training function to execute.
- This can either take in no arguments or a ``config`` dict.
- config (Optional[Dict]): Configurations to pass into
- ``train_func``. If None then an empty Dict will be created.
- train_func_context: Context manager for user's `train_func`, which executes
- backend-specific logic before and after the training function.
- fn_arg_name (Optional[str]): The name of training function to use for error
- messages.
- Returns:
- A valid training function.
- Raises:
- ValueError: if the input ``train_func`` is invalid.
- """
- num_required_params = count_required_parameters(train_func)
- if num_required_params > 1:
- err_msg = (
- f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
- f"{num_required_params} required arguments instead."
- )
- raise ValueError(err_msg)
- if num_required_params == 1:
- config = config or {}
- @functools.wraps(train_func)
- def train_fn():
- with train_func_context():
- return train_func(config)
- else: # num_params == 0
- @functools.wraps(train_func)
- def train_fn():
- with train_func_context():
- return train_func()
- return train_fn
- class ObjectRefWrapper(Generic[T]):
- """Thin wrapper around ray.put to manually control dereferencing."""
- def __init__(self, obj: T):
- self._ref = ray.put(obj)
- def get(self) -> T:
- return ray.get(self._ref)
- def date_str(include_ms: bool = False):
- pattern = "%Y-%m-%d_%H-%M-%S"
- if include_ms:
- pattern += ".%f"
- return datetime.today().strftime(pattern)
- def time_monotonic():
- return time.monotonic()
- def _copy_doc(copy_func):
- def wrapped(func):
- func.__doc__ = copy_func.__doc__
- return func
- return wrapped
- def ray_get_safe(
- object_refs: Union[ObjectRef, List[ObjectRef]],
- ) -> Union[Any, List[Any]]:
- """This is a safe version of `ray.get` that raises an exception immediately
- if an input task dies, while the others are still running.
- TODO(ml-team, core-team): This is NOT a long-term solution,
- and we should not maintain this function indefinitely.
- This is a mitigation for a Ray Core bug, and should be removed when
- that is fixed.
- See here: https://github.com/ray-project/ray/issues/47204
- Args:
- object_refs: A single or list of object refs to wait on.
- Returns:
- task_outputs: The outputs of the tasks.
- Raises:
- `RayTaskError`/`RayActorError`: if any of the tasks encounter a runtime error
- or fail due to actor/task death (ex: node failure).
- """
- is_list = isinstance(object_refs, list)
- object_refs = object_refs if is_list else [object_refs]
- unready = object_refs
- task_to_output = {}
- while unready:
- ready, unready = ray.wait(unready, num_returns=1)
- if ready:
- for task, task_output in zip(ready, ray.get(ready)):
- task_to_output[task] = task_output
- assert len(task_to_output) == len(object_refs)
- ordered_outputs = [task_to_output[task] for task in object_refs]
- return ordered_outputs if is_list else ordered_outputs[0]
- @contextlib.contextmanager
- def invoke_context_managers(
- context_managers: List[ContextManager],
- ) -> Generator[None, None, None]:
- """
- Utility to invoke a list of context managers and yield sequentially.
- Args:
- context_managers: List of context managers to invoke.
- """
- with contextlib.ExitStack() as stack:
- for context_manager in context_managers:
- stack.enter_context(context_manager())
- yield
- def get_module_name(obj: object) -> str:
- """Returns the full module name of the given object, including its qualified name.
- Args:
- obj: The object (class, function, etc.) whose module name is required.
- Returns:
- Full module and qualified name as a string.
- """
- return f"{obj.__module__}.{obj.__qualname__}"
- def get_callable_name(fn: Callable) -> str:
- """Returns a readable name for any callable.
- Examples:
- >>> get_callable_name(lambda x: x)
- '<lambda>'
- >>> def foo(a, b): pass
- >>> get_callable_name(foo)
- 'foo'
- >>> from functools import partial
- >>> bar = partial(partial(foo, a=1), b=2)
- >>> get_callable_name(bar)
- 'foo'
- >>> class Dummy:
- ... def __call__(self, a, b): pass
- >>> get_callable_name(Dummy())
- 'Dummy'
- """
- if isinstance(fn, functools.partial):
- return get_callable_name(fn.func)
- # Use __name__ for regular functions and lambdas
- if hasattr(fn, "__name__"):
- return fn.__name__
- # Fallback to the class name for objects that implement __call__
- return fn.__class__.__name__
- def construct_user_exception_with_traceback(
- e: BaseException, exclude_frames: int = 0
- ) -> UserExceptionWithTraceback:
- """Construct a UserExceptionWithTraceback from a base exception.
- Args:
- e: The base exception to construct a UserExceptionWithTraceback from.
- exclude_frames: The number of frames to exclude from the beginnning of
- the traceback.
- Returns:
- A UserExceptionWithTraceback object.
- """
- # TODO(justinvyu): This is brittle and may break if the call stack
- # changes. Figure out a more robust way to exclude these frames.
- exc_traceback_str = traceback.format_exc(
- limit=-(len(traceback.extract_tb(e.__traceback__)) - exclude_frames)
- )
- logger.error(f"Error in training function:\n{exc_traceback_str}")
- return UserExceptionWithTraceback(e, traceback_str=exc_traceback_str)
- def _in_ray_train_worker() -> bool:
- """Check if the current process is a Ray Train V2 worker."""
- from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
- try:
- get_train_fn_utils()
- return True
- except RuntimeError:
- return False
- def requires_train_worker(raise_in_tune_session: bool = False) -> Callable:
- """Check that the caller is a Ray Train worker spawned by Ray Train,
- with access to training function utilities.
- Args:
- raise_in_tune_session: Whether to raise a specific error message if the caller
- is in a Tune session. If True, will raise a DeprecationWarning.
- Returns:
- A decorator that performs this check, which raises an error if the caller
- is not a Ray Train worker.
- """
- def _wrap(fn: Callable) -> Callable:
- @functools.wraps(fn)
- def _wrapped_fn(*args, **kwargs):
- from ray.tune.trainable.trainable_fn_utils import _in_tune_session
- if raise_in_tune_session and _in_tune_session():
- raise DeprecationWarning(
- f"`ray.train.{fn.__name__}` is deprecated when running in a function "
- "passed to Ray Tune. Please use the equivalent `ray.tune` API instead. "
- "See this issue for more context: "
- "https://github.com/ray-project/ray/issues/49454"
- )
- if not _in_ray_train_worker():
- raise RuntimeError(
- f"`{fn.__name__}` cannot be used outside of a Ray Train training function. "
- "You are calling this API from the driver or another non-training process. "
- "These utilities are only available within a function launched by `trainer.fit()`."
- )
- return fn(*args, **kwargs)
- return _wrapped_fn
- return _wrap
|