util.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import contextlib
  2. import functools
  3. import logging
  4. import time
  5. import traceback
  6. from datetime import datetime
  7. from typing import (
  8. Any,
  9. Callable,
  10. ContextManager,
  11. Dict,
  12. Generator,
  13. Generic,
  14. List,
  15. Optional,
  16. TypeVar,
  17. Union,
  18. )
  19. import ray
  20. from ray.train._internal.utils import count_required_parameters
  21. from ray.train.v2._internal.exceptions import UserExceptionWithTraceback
  22. from ray.types import ObjectRef
  23. logger = logging.getLogger(__name__)
  24. T = TypeVar("T")
  25. def bundle_to_remote_args(bundle: dict) -> dict:
  26. """Convert a bundle of resources to Ray actor/task arguments.
  27. >>> bundle_to_remote_args({"GPU": 1, "memory": 1, "custom": 0.1})
  28. {'num_cpus': 0, 'num_gpus': 1, 'memory': 1, 'resources': {'custom': 0.1}}
  29. """
  30. bundle = bundle.copy()
  31. args = {
  32. "num_cpus": bundle.pop("CPU", 0),
  33. "num_gpus": bundle.pop("GPU", 0),
  34. "memory": bundle.pop("memory", 0),
  35. }
  36. if bundle:
  37. args["resources"] = bundle
  38. return args
  39. def construct_train_func(
  40. train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
  41. config: Optional[Dict[str, Any]],
  42. train_func_context: ContextManager,
  43. fn_arg_name: Optional[str] = "train_loop_per_worker",
  44. ) -> Callable[[], T]:
  45. """Validates and constructs the training function to execute.
  46. Args:
  47. train_func: The training function to execute.
  48. This can either take in no arguments or a ``config`` dict.
  49. config (Optional[Dict]): Configurations to pass into
  50. ``train_func``. If None then an empty Dict will be created.
  51. train_func_context: Context manager for user's `train_func`, which executes
  52. backend-specific logic before and after the training function.
  53. fn_arg_name (Optional[str]): The name of training function to use for error
  54. messages.
  55. Returns:
  56. A valid training function.
  57. Raises:
  58. ValueError: if the input ``train_func`` is invalid.
  59. """
  60. num_required_params = count_required_parameters(train_func)
  61. if num_required_params > 1:
  62. err_msg = (
  63. f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
  64. f"{num_required_params} required arguments instead."
  65. )
  66. raise ValueError(err_msg)
  67. if num_required_params == 1:
  68. config = config or {}
  69. @functools.wraps(train_func)
  70. def train_fn():
  71. with train_func_context():
  72. return train_func(config)
  73. else: # num_params == 0
  74. @functools.wraps(train_func)
  75. def train_fn():
  76. with train_func_context():
  77. return train_func()
  78. return train_fn
  79. class ObjectRefWrapper(Generic[T]):
  80. """Thin wrapper around ray.put to manually control dereferencing."""
  81. def __init__(self, obj: T):
  82. self._ref = ray.put(obj)
  83. def get(self) -> T:
  84. return ray.get(self._ref)
  85. def date_str(include_ms: bool = False):
  86. pattern = "%Y-%m-%d_%H-%M-%S"
  87. if include_ms:
  88. pattern += ".%f"
  89. return datetime.today().strftime(pattern)
  90. def time_monotonic():
  91. return time.monotonic()
  92. def _copy_doc(copy_func):
  93. def wrapped(func):
  94. func.__doc__ = copy_func.__doc__
  95. return func
  96. return wrapped
  97. def ray_get_safe(
  98. object_refs: Union[ObjectRef, List[ObjectRef]],
  99. ) -> Union[Any, List[Any]]:
  100. """This is a safe version of `ray.get` that raises an exception immediately
  101. if an input task dies, while the others are still running.
  102. TODO(ml-team, core-team): This is NOT a long-term solution,
  103. and we should not maintain this function indefinitely.
  104. This is a mitigation for a Ray Core bug, and should be removed when
  105. that is fixed.
  106. See here: https://github.com/ray-project/ray/issues/47204
  107. Args:
  108. object_refs: A single or list of object refs to wait on.
  109. Returns:
  110. task_outputs: The outputs of the tasks.
  111. Raises:
  112. `RayTaskError`/`RayActorError`: if any of the tasks encounter a runtime error
  113. or fail due to actor/task death (ex: node failure).
  114. """
  115. is_list = isinstance(object_refs, list)
  116. object_refs = object_refs if is_list else [object_refs]
  117. unready = object_refs
  118. task_to_output = {}
  119. while unready:
  120. ready, unready = ray.wait(unready, num_returns=1)
  121. if ready:
  122. for task, task_output in zip(ready, ray.get(ready)):
  123. task_to_output[task] = task_output
  124. assert len(task_to_output) == len(object_refs)
  125. ordered_outputs = [task_to_output[task] for task in object_refs]
  126. return ordered_outputs if is_list else ordered_outputs[0]
  127. @contextlib.contextmanager
  128. def invoke_context_managers(
  129. context_managers: List[ContextManager],
  130. ) -> Generator[None, None, None]:
  131. """
  132. Utility to invoke a list of context managers and yield sequentially.
  133. Args:
  134. context_managers: List of context managers to invoke.
  135. """
  136. with contextlib.ExitStack() as stack:
  137. for context_manager in context_managers:
  138. stack.enter_context(context_manager())
  139. yield
  140. def get_module_name(obj: object) -> str:
  141. """Returns the full module name of the given object, including its qualified name.
  142. Args:
  143. obj: The object (class, function, etc.) whose module name is required.
  144. Returns:
  145. Full module and qualified name as a string.
  146. """
  147. return f"{obj.__module__}.{obj.__qualname__}"
  148. def get_callable_name(fn: Callable) -> str:
  149. """Returns a readable name for any callable.
  150. Examples:
  151. >>> get_callable_name(lambda x: x)
  152. '<lambda>'
  153. >>> def foo(a, b): pass
  154. >>> get_callable_name(foo)
  155. 'foo'
  156. >>> from functools import partial
  157. >>> bar = partial(partial(foo, a=1), b=2)
  158. >>> get_callable_name(bar)
  159. 'foo'
  160. >>> class Dummy:
  161. ... def __call__(self, a, b): pass
  162. >>> get_callable_name(Dummy())
  163. 'Dummy'
  164. """
  165. if isinstance(fn, functools.partial):
  166. return get_callable_name(fn.func)
  167. # Use __name__ for regular functions and lambdas
  168. if hasattr(fn, "__name__"):
  169. return fn.__name__
  170. # Fallback to the class name for objects that implement __call__
  171. return fn.__class__.__name__
  172. def construct_user_exception_with_traceback(
  173. e: BaseException, exclude_frames: int = 0
  174. ) -> UserExceptionWithTraceback:
  175. """Construct a UserExceptionWithTraceback from a base exception.
  176. Args:
  177. e: The base exception to construct a UserExceptionWithTraceback from.
  178. exclude_frames: The number of frames to exclude from the beginnning of
  179. the traceback.
  180. Returns:
  181. A UserExceptionWithTraceback object.
  182. """
  183. # TODO(justinvyu): This is brittle and may break if the call stack
  184. # changes. Figure out a more robust way to exclude these frames.
  185. exc_traceback_str = traceback.format_exc(
  186. limit=-(len(traceback.extract_tb(e.__traceback__)) - exclude_frames)
  187. )
  188. logger.error(f"Error in training function:\n{exc_traceback_str}")
  189. return UserExceptionWithTraceback(e, traceback_str=exc_traceback_str)
  190. def _in_ray_train_worker() -> bool:
  191. """Check if the current process is a Ray Train V2 worker."""
  192. from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
  193. try:
  194. get_train_fn_utils()
  195. return True
  196. except RuntimeError:
  197. return False
  198. def requires_train_worker(raise_in_tune_session: bool = False) -> Callable:
  199. """Check that the caller is a Ray Train worker spawned by Ray Train,
  200. with access to training function utilities.
  201. Args:
  202. raise_in_tune_session: Whether to raise a specific error message if the caller
  203. is in a Tune session. If True, will raise a DeprecationWarning.
  204. Returns:
  205. A decorator that performs this check, which raises an error if the caller
  206. is not a Ray Train worker.
  207. """
  208. def _wrap(fn: Callable) -> Callable:
  209. @functools.wraps(fn)
  210. def _wrapped_fn(*args, **kwargs):
  211. from ray.tune.trainable.trainable_fn_utils import _in_tune_session
  212. if raise_in_tune_session and _in_tune_session():
  213. raise DeprecationWarning(
  214. f"`ray.train.{fn.__name__}` is deprecated when running in a function "
  215. "passed to Ray Tune. Please use the equivalent `ray.tune` API instead. "
  216. "See this issue for more context: "
  217. "https://github.com/ray-project/ray/issues/49454"
  218. )
  219. if not _in_ray_train_worker():
  220. raise RuntimeError(
  221. f"`{fn.__name__}` cannot be used outside of a Ray Train training function. "
  222. "You are calling this API from the driver or another non-training process. "
  223. "These utilities are only available within a function launched by `trainer.fit()`."
  224. )
  225. return fn(*args, **kwargs)
  226. return _wrapped_fn
  227. return _wrap