utils.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import abc
  2. import functools
  3. import inspect
  4. import logging
  5. import os
  6. import socket
  7. from typing import (
  8. Any,
  9. Callable,
  10. ContextManager,
  11. Dict,
  12. List,
  13. Optional,
  14. Tuple,
  15. TypeVar,
  16. Union,
  17. )
  18. import ray
  19. from ray._common.network_utils import find_free_port, is_ipv6
  20. from ray.actor import ActorHandle
  21. from ray.air._internal.util import (
  22. StartTraceback,
  23. StartTracebackWithWorkerRank,
  24. )
  25. from ray.exceptions import RayActorError
  26. from ray.types import ObjectRef
  27. T = TypeVar("T")
  28. logger = logging.getLogger(__name__)
  29. def check_for_failure(
  30. remote_values: List[ObjectRef],
  31. ) -> Tuple[bool, Optional[Exception]]:
  32. """Check for actor failure when retrieving the remote values.
  33. Args:
  34. remote_values: List of object references from Ray actor methods.
  35. Returns:
  36. A tuple of (bool, Exception). The bool is
  37. True if evaluating all object references is successful, False otherwise.
  38. """
  39. unfinished = remote_values.copy()
  40. while len(unfinished) > 0:
  41. finished, unfinished = ray.wait(unfinished)
  42. # If a failure occurs the ObjectRef will be marked as finished.
  43. # Calling ray.get will expose the failure as a RayActorError.
  44. for object_ref in finished:
  45. # Everything in finished has either failed or completed
  46. # successfully.
  47. try:
  48. ray.get(object_ref)
  49. except RayActorError as exc:
  50. failed_actor_rank = remote_values.index(object_ref)
  51. logger.info(f"Worker {failed_actor_rank} has failed.")
  52. return False, exc
  53. except Exception as exc:
  54. # Other (e.g. training) errors should be directly raised
  55. failed_worker_rank = remote_values.index(object_ref)
  56. raise StartTracebackWithWorkerRank(
  57. worker_rank=failed_worker_rank
  58. ) from exc
  59. return True, None
  60. def get_address_and_port() -> Tuple[str, int]:
  61. """Returns the IP address and a free port on this node."""
  62. addr = ray.util.get_node_ip_address()
  63. port = find_free_port(socket.AF_INET6 if is_ipv6(addr) else socket.AF_INET)
  64. return addr, port
  65. def update_env_vars(env_vars: Dict[str, Any]):
  66. """Updates the environment variables on this worker process.
  67. Args:
  68. env_vars: Environment variables to set.
  69. """
  70. sanitized = {k: str(v) for k, v in env_vars.items()}
  71. os.environ.update(sanitized)
  72. def count_required_parameters(fn: Callable) -> int:
  73. """Counts the number of required parameters of a function.
  74. NOTE: *args counts as 1 required parameter.
  75. Examples
  76. --------
  77. >>> def fn(a, b, /, c, *args, d=1, e=2, **kwargs):
  78. ... pass
  79. >>> count_required_parameters(fn)
  80. 4
  81. >>> fn = lambda: 1
  82. >>> count_required_parameters(fn)
  83. 0
  84. >>> def fn(config, a, b=1, c=2):
  85. ... pass
  86. >>> from functools import partial
  87. >>> count_required_parameters(partial(fn, a=0))
  88. 1
  89. """
  90. params = inspect.signature(fn).parameters.values()
  91. positional_param_kinds = {
  92. inspect.Parameter.POSITIONAL_ONLY,
  93. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  94. inspect.Parameter.VAR_POSITIONAL,
  95. }
  96. return len(
  97. [
  98. p
  99. for p in params
  100. if p.default == inspect.Parameter.empty and p.kind in positional_param_kinds
  101. ]
  102. )
  103. def construct_train_func(
  104. train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
  105. config: Optional[Dict[str, Any]],
  106. train_func_context: ContextManager,
  107. fn_arg_name: Optional[str] = "train_func",
  108. discard_returns: bool = False,
  109. ) -> Callable[[], T]:
  110. """Validates and constructs the training function to execute.
  111. Args:
  112. train_func: The training function to execute.
  113. This can either take in no arguments or a ``config`` dict.
  114. config (Optional[Dict]): Configurations to pass into
  115. ``train_func``. If None then an empty Dict will be created.
  116. train_func_context: Context manager for user's `train_func`, which executes
  117. backend-specific logic before and after the training function.
  118. fn_arg_name (Optional[str]): The name of training function to use for error
  119. messages.
  120. discard_returns: Whether to discard any returns from train_func or not.
  121. Returns:
  122. A valid training function.
  123. Raises:
  124. ValueError: if the input ``train_func`` is invalid.
  125. """
  126. num_required_params = count_required_parameters(train_func)
  127. if discard_returns:
  128. # Discard any returns from the function so that
  129. # BackendExecutor doesn't try to deserialize them.
  130. # Those returns are inaccesible with AIR anyway.
  131. @functools.wraps(train_func)
  132. def discard_return_wrapper(*args, **kwargs):
  133. try:
  134. train_func(*args, **kwargs)
  135. except Exception as e:
  136. raise StartTraceback from e
  137. wrapped_train_func = discard_return_wrapper
  138. else:
  139. wrapped_train_func = train_func
  140. if num_required_params > 1:
  141. err_msg = (
  142. f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
  143. f"{num_required_params} required arguments instead."
  144. )
  145. raise ValueError(err_msg)
  146. elif num_required_params == 1:
  147. config = {} if config is None else config
  148. @functools.wraps(wrapped_train_func)
  149. def train_fn():
  150. try:
  151. with train_func_context():
  152. return wrapped_train_func(config)
  153. except Exception as e:
  154. raise StartTraceback from e
  155. else: # num_params == 0
  156. @functools.wraps(wrapped_train_func)
  157. def train_fn():
  158. try:
  159. with train_func_context():
  160. return wrapped_train_func()
  161. except Exception as e:
  162. raise StartTraceback from e
  163. return train_fn
  164. class Singleton(abc.ABCMeta):
  165. """Singleton Abstract Base Class
  166. https://stackoverflow.com/questions/33364070/implementing
  167. -singleton-as-metaclass-but-for-abstract-classes
  168. """
  169. _instances = {}
  170. def __call__(cls, *args, **kwargs):
  171. if cls not in cls._instances:
  172. cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
  173. return cls._instances[cls]
  174. class ActorWrapper:
  175. """Wraps an actor to provide same API as using the base class directly."""
  176. def __init__(self, actor: ActorHandle):
  177. self.actor = actor
  178. def __getattr__(self, item):
  179. # The below will fail if trying to access an attribute (not a method) from the
  180. # actor.
  181. actor_method = getattr(self.actor, item)
  182. return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))