remote_function.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. import inspect
  2. import logging
  3. import os
  4. import uuid
  5. from functools import wraps
  6. from threading import Lock
  7. from typing import Optional
  8. import ray._common.signature
  9. from ray import Language, cross_language
  10. from ray._common import ray_option_utils
  11. from ray._common.ray_option_utils import _warn_if_using_deprecated_placement_group
  12. from ray._common.serialization import pickle_dumps
  13. from ray._private.auto_init_hook import wrap_auto_init
  14. from ray._private.client_mode_hook import (
  15. client_mode_convert_function,
  16. client_mode_should_convert,
  17. )
  18. from ray._private.utils import get_runtime_env_info, parse_runtime_env_for_task_or_actor
  19. from ray._raylet import (
  20. STREAMING_GENERATOR_RETURN,
  21. ObjectRefGenerator,
  22. PythonFunctionDescriptor,
  23. )
  24. from ray.util.annotations import DeveloperAPI, PublicAPI
  25. from ray.util.placement_group import _configure_placement_group_based_on_context
  26. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  27. from ray.util.tracing.tracing_helper import (
  28. _inject_tracing_into_function,
  29. _tracing_task_invocation,
  30. )
  31. logger = logging.getLogger(__name__)
  32. # Hook to call with (fn, resources, strategy) on each local task submission.
  33. _task_launch_hook = None
  34. @PublicAPI
  35. class RemoteFunction:
  36. """A remote function.
  37. This is a decorated function. It can be used to spawn tasks.
  38. Attributes:
  39. _language: The target language.
  40. _function: The original function.
  41. _function_descriptor: The function descriptor. This is not defined
  42. until the remote function is first invoked because that is when the
  43. function is pickled, and the pickled function is used to compute
  44. the function descriptor.
  45. _function_name: The module and function name.
  46. _num_cpus: The default number of CPUs to use for invocations of this
  47. remote function.
  48. _num_gpus: The default number of GPUs to use for invocations of this
  49. remote function.
  50. _memory: The heap memory request in bytes for this task/actor,
  51. rounded down to the nearest integer.
  52. _label_selector: The label requirements on a node for scheduling of the task or actor.
  53. _fallback_strategy: Soft constraints of a list of decorator options to fall back on when scheduling on a node.
  54. _resources: The default custom resource requirements for invocations of
  55. this remote function.
  56. _num_returns: The default number of return values for invocations
  57. of this remote function.
  58. _max_calls: The number of times a worker can execute this function
  59. before exiting.
  60. _max_retries: The number of times this task may be retried
  61. on worker failure.
  62. _retry_exceptions: Whether application-level errors should be retried.
  63. This can be a boolean or a list/tuple of exceptions that should be retried.
  64. _runtime_env: The runtime environment for this task.
  65. _decorator: An optional decorator that should be applied to the remote
  66. function invocation (as opposed to the function execution) before
  67. invoking the function. The decorator must return a function that
  68. takes in two arguments ("args" and "kwargs"). In most cases, it
  69. should call the function that was passed into the decorator and
  70. return the resulting ObjectRefs. For an example, see
  71. "test_decorated_function" in "python/ray/tests/test_basic.py".
  72. _function_signature: The function signature.
  73. _last_export_cluster_and_job: A pair of the last exported cluster
  74. and job to help us to know whether this function was exported.
  75. This is an imperfect mechanism used to determine if we need to
  76. export the remote function again. It is imperfect in the sense that
  77. the actor class definition could be exported multiple times by
  78. different workers.
  79. _scheduling_strategy: Strategy about how to schedule
  80. this remote function.
  81. """
  82. def __init__(
  83. self,
  84. language,
  85. function,
  86. function_descriptor,
  87. task_options,
  88. ):
  89. if inspect.iscoroutinefunction(function):
  90. raise ValueError(
  91. "'async def' should not be used for remote tasks. You can wrap the "
  92. "async function with `asyncio.run(f())`. See more at:"
  93. "https://docs.ray.io/en/latest/ray-core/actors/async_api.html "
  94. )
  95. self._default_options = task_options
  96. # When gpu is used, set the task non-recyclable by default.
  97. # https://github.com/ray-project/ray/issues/29624 for more context.
  98. # Note: Ray task worker process is not being reused when nsight
  99. # profiler is running, as nsight/rocprof-sys generate report
  100. # once the process exit.
  101. num_gpus = self._default_options.get("num_gpus") or 0
  102. if (
  103. num_gpus > 0 and self._default_options.get("max_calls", None) is None
  104. ) or any(
  105. [
  106. s in (self._default_options.get(s) or {})
  107. for s in ["nsight", "rocprof-sys"]
  108. ]
  109. ):
  110. self._default_options["max_calls"] = 1
  111. # TODO(suquark): This is a workaround for class attributes of options.
  112. # They are being used in some other places, mostly tests. Need cleanup later.
  113. # E.g., actors uses "__ray_metadata__" to collect options, we can so something
  114. # similar for remote functions.
  115. for k, v in ray_option_utils.task_options.items():
  116. setattr(self, "_" + k, task_options.get(k, v.default_value))
  117. self._runtime_env = parse_runtime_env_for_task_or_actor(self._runtime_env)
  118. if "runtime_env" in self._default_options:
  119. self._default_options["runtime_env"] = self._runtime_env
  120. # Pre-calculate runtime env info, to avoid re-calculation at `remote`
  121. # invocation. When `remote` call has specified extra `option` field,
  122. # runtime env will be overwritten and re-serialized.
  123. #
  124. # Caveat: To support dynamic runtime envs in
  125. # `func.option(runtime_env={...}).remote()`, we recalculate the serialized
  126. # runtime env info in the `option` call. But it's acceptable since
  127. # pre-calculation here only happens once at `RemoteFunction` initialization.
  128. self._serialized_base_runtime_env_info = ""
  129. if self._runtime_env:
  130. self._serialized_base_runtime_env_info = get_runtime_env_info(
  131. self._runtime_env,
  132. is_job_runtime_env=False,
  133. serialize=True,
  134. )
  135. self._language = language
  136. self._is_generator = inspect.isgeneratorfunction(function)
  137. self._function = function
  138. self._function_signature = None
  139. # Guards trace injection to enforce exactly once semantics
  140. self._inject_lock = Lock()
  141. self._function_name = function.__module__ + "." + function.__name__
  142. self._function_descriptor = function_descriptor
  143. self._is_cross_language = language != Language.PYTHON
  144. self._decorator = getattr(function, "__ray_invocation_decorator__", None)
  145. self._last_export_cluster_and_job = None
  146. self._uuid = uuid.uuid4()
  147. # Override task.remote's signature and docstring
  148. @wraps(function)
  149. def _remote_proxy(*args, **kwargs):
  150. return self._remote(
  151. serialized_runtime_env_info=self._serialized_base_runtime_env_info,
  152. args=args,
  153. kwargs=kwargs,
  154. **self._default_options,
  155. )
  156. self.remote = _remote_proxy
  157. def __call__(self, *args, **kwargs):
  158. raise TypeError(
  159. "Remote functions cannot be called directly. Instead "
  160. f"of running '{self._function_name}()', "
  161. f"try '{self._function_name}.remote()'."
  162. )
  163. # Lock is not picklable
  164. def __getstate__(self):
  165. attrs = self.__dict__.copy()
  166. del attrs["_inject_lock"]
  167. return attrs
  168. def __setstate__(self, state):
  169. self.__dict__.update(state)
  170. self.__dict__["_inject_lock"] = Lock()
  171. def options(self, **task_options):
  172. """Configures and overrides the task invocation parameters.
  173. The arguments are the same as those that can be passed to :obj:`ray.remote`.
  174. Overriding `max_calls` is not supported.
  175. Args:
  176. num_returns: It specifies the number of object refs returned by
  177. the remote function invocation.
  178. num_cpus: The quantity of CPU cores to reserve
  179. for this task or for the lifetime of the actor.
  180. num_gpus: The quantity of GPUs to reserve
  181. for this task or for the lifetime of the actor.
  182. resources (Dict[str, float]): The quantity of various custom resources
  183. to reserve for this task or for the lifetime of the actor.
  184. This is a dictionary mapping strings (resource names) to floats.
  185. label_selector (Dict[str, str]): If specified, the labels required for the node on
  186. which this actor can be scheduled on. The label selector consist of key-value pairs,
  187. where the keys are label names and the value are expressions consisting of an operator
  188. with label values or just a value to indicate equality.
  189. fallback_strategy (List[Dict[str, Any]]): If specified, expresses soft constraints
  190. through a list of decorator options to fall back on when scheduling on a node.
  191. accelerator_type: If specified, requires that the task or actor run
  192. on a node with the specified type of accelerator.
  193. See :ref:`accelerator types <accelerator_types>`.
  194. memory: The heap memory request in bytes for this task/actor,
  195. rounded down to the nearest integer.
  196. object_store_memory: The object store memory request for actors only.
  197. max_calls: This specifies the
  198. maximum number of times that a given worker can execute
  199. the given remote function before it must exit
  200. (this can be used to address memory leaks in third-party
  201. libraries or to reclaim resources that cannot easily be
  202. released, e.g., GPU memory that was acquired by TensorFlow).
  203. By default this is infinite for CPU tasks and 1 for GPU tasks
  204. (to force GPU tasks to release resources after finishing).
  205. max_retries: This specifies the maximum number of times that the remote
  206. function should be rerun when the worker process executing it
  207. crashes unexpectedly. The minimum valid value is 0,
  208. the default is 3 (default), and a value of -1 indicates
  209. infinite retries.
  210. runtime_env (Dict[str, Any]): Specifies the runtime environment for
  211. this actor or task and its children. See
  212. :ref:`runtime-environments` for detailed documentation.
  213. retry_exceptions: This specifies whether application-level errors
  214. should be retried up to max_retries times.
  215. scheduling_strategy: Strategy about how to
  216. schedule a remote function or actor. Possible values are
  217. None: ray will figure out the scheduling strategy to use, it
  218. will either be the PlacementGroupSchedulingStrategy using parent's
  219. placement group if parent has one and has
  220. placement_group_capture_child_tasks set to true,
  221. or "DEFAULT";
  222. "DEFAULT": default hybrid scheduling;
  223. "SPREAD": best effort spread scheduling;
  224. `PlacementGroupSchedulingStrategy`:
  225. placement group based scheduling;
  226. `NodeAffinitySchedulingStrategy`:
  227. node id based affinity scheduling.
  228. enable_task_events: This specifies whether to enable task events for this
  229. task. If set to True, task events such as (task running, finished)
  230. are emitted, and available to Ray Dashboard and State API.
  231. See :ref:`state-api-overview-ref` for more details.
  232. _labels: The key-value labels of a task.
  233. Examples:
  234. .. code-block:: python
  235. @ray.remote(num_gpus=1, max_calls=1, num_returns=2)
  236. def f():
  237. return 1, 2
  238. # Task g will require 2 gpus instead of 1.
  239. g = f.options(num_gpus=2)
  240. """
  241. func_cls = self
  242. # override original options
  243. default_options = self._default_options.copy()
  244. # max_calls could not be used in ".options()", we should remove it before
  245. # merging options from '@ray.remote'.
  246. default_options.pop("max_calls", None)
  247. updated_options = ray_option_utils.update_options(default_options, task_options)
  248. ray_option_utils.validate_task_options(updated_options, in_options=True)
  249. # Only update runtime_env and re-calculate serialized runtime env info when
  250. # ".options()" specifies new runtime_env.
  251. serialized_runtime_env_info = self._serialized_base_runtime_env_info
  252. if "runtime_env" in task_options:
  253. updated_options["runtime_env"] = parse_runtime_env_for_task_or_actor(
  254. updated_options["runtime_env"]
  255. )
  256. # Re-calculate runtime env info based on updated runtime env.
  257. if updated_options["runtime_env"]:
  258. serialized_runtime_env_info = get_runtime_env_info(
  259. updated_options["runtime_env"],
  260. is_job_runtime_env=False,
  261. serialize=True,
  262. )
  263. class FuncWrapper:
  264. def remote(self, *args, **kwargs):
  265. return func_cls._remote(
  266. args=args,
  267. kwargs=kwargs,
  268. serialized_runtime_env_info=serialized_runtime_env_info,
  269. **updated_options,
  270. )
  271. @DeveloperAPI
  272. def bind(self, *args, **kwargs):
  273. """
  274. For Ray DAG building that creates static graph from decorated
  275. class or functions.
  276. """
  277. from ray.dag.function_node import FunctionNode
  278. return FunctionNode(func_cls._function, args, kwargs, updated_options)
  279. return FuncWrapper()
  280. @wrap_auto_init
  281. @_tracing_task_invocation
  282. def _remote(
  283. self,
  284. args=None,
  285. kwargs=None,
  286. serialized_runtime_env_info: Optional[str] = None,
  287. **task_options,
  288. ):
  289. """Submit the remote function for execution."""
  290. # We pop the "max_calls" coming from "@ray.remote" here. We no longer need
  291. # it in "_remote()".
  292. task_options.pop("max_calls", None)
  293. if client_mode_should_convert():
  294. return client_mode_convert_function(self, args, kwargs, **task_options)
  295. worker = ray._private.worker.global_worker
  296. worker.check_connected()
  297. if worker.mode != ray._private.worker.WORKER_MODE:
  298. # Only need to record on the driver side
  299. # since workers are created via tasks or actors
  300. # launched from the driver.
  301. from ray._common.usage import usage_lib
  302. usage_lib.record_library_usage("core")
  303. # We cannot do this when the function is first defined, because we need
  304. # ray.init() to have been called when this executes
  305. with self._inject_lock:
  306. if self._function_signature is None:
  307. self._function = _inject_tracing_into_function(self._function)
  308. self._function_signature = ray._common.signature.extract_signature(
  309. self._function
  310. )
  311. # If this function was not exported in this cluster and job, we need to
  312. # export this function again, because the current GCS doesn't have it.
  313. if (
  314. not self._is_cross_language
  315. and self._last_export_cluster_and_job != worker.current_cluster_and_job
  316. ):
  317. self._function_descriptor = PythonFunctionDescriptor.from_function(
  318. self._function, self._uuid
  319. )
  320. # There is an interesting question here. If the remote function is
  321. # used by a subsequent driver (in the same script), should the
  322. # second driver pickle the function again? If yes, then the remote
  323. # function definition can differ in the second driver (e.g., if
  324. # variables in its closure have changed). We probably want the
  325. # behavior of the remote function in the second driver to be
  326. # independent of whether or not the function was invoked by the
  327. # first driver. This is an argument for repickling the function,
  328. # which we do here.
  329. self._pickled_function = pickle_dumps(
  330. self._function,
  331. f"Could not serialize the function {self._function_descriptor.repr}",
  332. )
  333. self._last_export_cluster_and_job = worker.current_cluster_and_job
  334. worker.function_actor_manager.export(self)
  335. kwargs = {} if kwargs is None else kwargs
  336. args = [] if args is None else args
  337. # fill task required options
  338. for k, v in ray_option_utils.task_options.items():
  339. if k == "max_retries":
  340. # TODO(swang): We need to override max_retries here because the default
  341. # value gets set at Ray import time. Ideally, we should allow setting
  342. # default values from env vars for other options too.
  343. v.default_value = os.environ.get(
  344. "RAY_TASK_MAX_RETRIES", v.default_value
  345. )
  346. v.default_value = int(v.default_value)
  347. task_options[k] = task_options.get(k, v.default_value)
  348. # "max_calls" already takes effects and should not apply again.
  349. # Remove the default value here.
  350. task_options.pop("max_calls", None)
  351. # TODO(suquark): cleanup these fields
  352. name = task_options["name"]
  353. placement_group = task_options["placement_group"]
  354. placement_group_bundle_index = task_options["placement_group_bundle_index"]
  355. placement_group_capture_child_tasks = task_options[
  356. "placement_group_capture_child_tasks"
  357. ]
  358. scheduling_strategy = task_options["scheduling_strategy"]
  359. num_returns = task_options["num_returns"]
  360. if num_returns is None:
  361. if self._is_generator:
  362. num_returns = "streaming"
  363. else:
  364. num_returns = 1
  365. if num_returns == "dynamic":
  366. num_returns = -1
  367. elif num_returns == "streaming":
  368. # TODO(sang): This is a temporary private API.
  369. # Remove it when we migrate to the streaming generator.
  370. num_returns = ray._raylet.STREAMING_GENERATOR_RETURN
  371. generator_backpressure_num_objects = task_options[
  372. "_generator_backpressure_num_objects"
  373. ]
  374. if generator_backpressure_num_objects is None:
  375. generator_backpressure_num_objects = -1
  376. max_retries = task_options["max_retries"]
  377. retry_exceptions = task_options["retry_exceptions"]
  378. if isinstance(retry_exceptions, (list, tuple)):
  379. retry_exception_allowlist = tuple(retry_exceptions)
  380. retry_exceptions = True
  381. else:
  382. retry_exception_allowlist = None
  383. if scheduling_strategy is None or not isinstance(
  384. scheduling_strategy, PlacementGroupSchedulingStrategy
  385. ):
  386. _warn_if_using_deprecated_placement_group(task_options, 4)
  387. resources = ray._common.utils.resources_from_ray_options(task_options)
  388. if scheduling_strategy is None or isinstance(
  389. scheduling_strategy, PlacementGroupSchedulingStrategy
  390. ):
  391. if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy):
  392. placement_group = scheduling_strategy.placement_group
  393. placement_group_bundle_index = (
  394. scheduling_strategy.placement_group_bundle_index
  395. )
  396. placement_group_capture_child_tasks = (
  397. scheduling_strategy.placement_group_capture_child_tasks
  398. )
  399. if placement_group_capture_child_tasks is None:
  400. placement_group_capture_child_tasks = (
  401. worker.should_capture_child_tasks_in_placement_group
  402. )
  403. placement_group = _configure_placement_group_based_on_context(
  404. placement_group_capture_child_tasks,
  405. placement_group_bundle_index,
  406. resources,
  407. {}, # no placement_resources for tasks
  408. self._function_descriptor.function_name,
  409. placement_group=placement_group,
  410. )
  411. if not placement_group.is_empty:
  412. scheduling_strategy = PlacementGroupSchedulingStrategy(
  413. placement_group,
  414. placement_group_bundle_index,
  415. placement_group_capture_child_tasks,
  416. )
  417. else:
  418. scheduling_strategy = "DEFAULT"
  419. if _task_launch_hook:
  420. _task_launch_hook(self._function_descriptor, resources, scheduling_strategy)
  421. # Override enable_task_events to default for actor if not specified (i.e. None)
  422. enable_task_events = task_options.get("enable_task_events")
  423. labels = task_options.get("_labels")
  424. label_selector = task_options.get("label_selector")
  425. fallback_strategy = task_options.get("fallback_strategy")
  426. def invocation(args, kwargs):
  427. if self._is_cross_language:
  428. list_args = cross_language._format_args(worker, args, kwargs)
  429. elif not args and not kwargs and not self._function_signature:
  430. list_args = []
  431. else:
  432. list_args = ray._common.signature.flatten_args(
  433. self._function_signature, args, kwargs
  434. )
  435. if worker.mode == ray._private.worker.LOCAL_MODE:
  436. assert (
  437. not self._is_cross_language
  438. ), "Cross language remote function cannot be executed locally."
  439. object_refs = worker.core_worker.submit_task(
  440. self._language,
  441. self._function_descriptor,
  442. list_args,
  443. name if name is not None else "",
  444. num_returns,
  445. resources,
  446. max_retries,
  447. retry_exceptions,
  448. retry_exception_allowlist,
  449. scheduling_strategy,
  450. worker.debugger_breakpoint,
  451. serialized_runtime_env_info or "{}",
  452. generator_backpressure_num_objects,
  453. enable_task_events,
  454. labels,
  455. label_selector,
  456. fallback_strategy,
  457. )
  458. # Reset worker's debug context from the last "remote" command
  459. # (which applies only to this .remote call).
  460. worker.debugger_breakpoint = b""
  461. if num_returns == STREAMING_GENERATOR_RETURN:
  462. # Streaming generator will return a single ref
  463. # that is for the generator task.
  464. assert len(object_refs) == 1
  465. generator_ref = object_refs[0]
  466. return ObjectRefGenerator(generator_ref, worker)
  467. if len(object_refs) == 1:
  468. return object_refs[0]
  469. elif len(object_refs) > 1:
  470. return object_refs
  471. if self._decorator is not None:
  472. invocation = self._decorator(invocation)
  473. return invocation(args, kwargs)
  474. @DeveloperAPI
  475. def bind(self, *args, **kwargs):
  476. """
  477. For Ray DAG building that creates static graph from decorated
  478. class or functions.
  479. """
  480. from ray.dag.function_node import FunctionNode
  481. return FunctionNode(self._function, args, kwargs, self._default_options)