| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714 |
- import dis
- import hashlib
- import importlib
- import inspect
- import json
- import logging
- import os
- import sys
- import threading
- import time
- import traceback
- from collections import defaultdict, namedtuple
- from typing import Callable, Optional
- import ray
- import ray._private.profiling as profiling
- from ray import cloudpickle as pickle
- from ray._common.serialization import pickle_dumps
- from ray._private import ray_constants
- from ray._private.inspect_util import (
- is_class_method,
- is_function_or_method,
- is_static_method,
- )
- from ray._private.ray_constants import KV_NAMESPACE_FUNCTION_TABLE
- from ray._private.utils import (
- check_oversized_function,
- ensure_str,
- format_error_message,
- )
- from ray._raylet import (
- WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS,
- JobID,
- PythonFunctionDescriptor,
- )
- from ray.remote_function import RemoteFunction
- from ray.util.tracing.tracing_helper import _inject_tracing_into_class
# FunctionExecutionInfo: A named tuple storing remote function information
# registered on this worker: the callable to invoke, its name, and its
# `max_calls` value (0 for functions loaded from local code and for actor
# methods).
FunctionExecutionInfo = namedtuple(
    "FunctionExecutionInfo", ["function", "function_name", "max_calls"]
)

# ImportedFunctionInfo: A named tuple holding a remote function entry as
# fetched from the GCS function table (the pickled function is kept in the
# `function` field; absent fields are None).
ImportedFunctionInfo = namedtuple(
    "ImportedFunctionInfo",
    ["job_id", "function_id", "function_name", "function", "module", "max_calls"],
)

logger = logging.getLogger(__name__)
def make_function_table_key(key_type: bytes, job_id: JobID, key: Optional[bytes]):
    """Build a function-table key of the form ``key_type:<job hex>[:key]``.

    Args:
        key_type: Prefix naming the kind of entry (e.g. b"RemoteFunction").
        job_id: The job identifier; only its ``hex()`` method is used.
        key: Optional trailing component (e.g. a function ID). When None,
            the key ends after the job component.

    Returns:
        The colon-joined key as bytes.
    """
    components = [key_type, job_id.hex().encode()]
    if key is not None:
        components.append(key)
    return b":".join(components)
class FunctionActorManager:
    """A class used to export/load remote functions and actors.

    Attributes:
        _worker: The worker that this manager belongs to.
        _functions_to_export: The remote functions to export when
            the worker gets connected.
        _actors_to_export: The actors to export when the worker gets
            connected.
        _function_execution_info: Maps function IDs to their
            FunctionExecutionInfo.
        _num_task_executions: Maps function IDs to the number of times the
            function has been executed on this worker.
        imported_actor_classes: The set of actor classes keys (format:
            ActorClass:function_id) that are already in GCS.
    """

    def __init__(self, worker):
        self._worker = worker
        self._functions_to_export = []
        self._actors_to_export = []
        # Maps function IDs to FunctionExecutionInfo objects. This should
        # only be used on workers that execute remote functions.
        self._function_execution_info = defaultdict(lambda: {})
        # Maps function IDs to execution counts. The default must be an int:
        # with a `lambda: {}` default, `increase_task_counter` on an ID that
        # was never explicitly initialized to 0 raised TypeError ({} += 1).
        self._num_task_executions = defaultdict(int)
        # A set of all of the actor class keys that have been imported by the
        # import thread. It is safe to convert this worker into an actor of
        # these types.
        self.imported_actor_classes = set()
        # Cache of actor classes already loaded by `load_actor_class`,
        # keyed by the actor creation function ID.
        self._loaded_actor_classes = {}
        # Deserializing an ActorHandle calls load_actor_class(). If a
        # function closure captured an ActorHandle, deserializing the
        # function goes through:
        #   -> fetch_and_register_remote_function (acquire lock)
        #   -> _load_actor_class_from_gcs (acquire lock, too)
        # so the lock must be reentrant.
        self.lock = threading.RLock()
        self.execution_infos = {}
        # Counter of how many keys have already been exported so that the
        # next key can be found quicker.
        self._num_exported = 0
        # Protects self._num_exported while exporting.
        self._export_lock = threading.Lock()

    def increase_task_counter(self, function_descriptor):
        """Increment the execution count of the described function."""
        function_id = function_descriptor.function_id
        self._num_task_executions[function_id] += 1

    def get_task_counter(self, function_descriptor):
        """Return how many times the described function has been executed."""
        function_id = function_descriptor.function_id
        return self._num_task_executions[function_id]

    def compute_collision_identifier(self, function_or_class):
        """The identifier is used to detect excessive duplicate exports.

        The identifier is used to determine when the same function or class is
        exported many times. This can yield false positives.

        Args:
            function_or_class: The function or class to compute an identifier
                for.

        Returns:
            The SHA-256 digest of the object's name plus its disassembly.
            Note that different functions or classes can give rise to the
            same identifier. However, the same function should hopefully
            always give rise to the same identifier. TODO(rkn): verify if
            this is actually the case. Note that if the identifier is
            incorrect in any way, then we may give warnings unnecessarily or
            fail to give warnings, but the application's behavior won't
            change.
        """
        import io

        string_file = io.StringIO()
        dis.dis(function_or_class, file=string_file, depth=2)
        collision_identifier = function_or_class.__name__ + ":" + string_file.getvalue()
        # Return a hash of the identifier in case it is too large.
        return hashlib.sha256(collision_identifier.encode("utf-8")).digest()

    def load_function_or_class_from_local(self, module_name, function_or_class_name):
        """Try to load a function or class in the module from local code.

        Args:
            module_name: Name of the module to import.
            function_or_class_name: Dotted attribute path inside the module
                (e.g. "MyClass.method"); empty path segments are ignored.

        Returns:
            The resolved object, or None if any attribute lookup fails.
            Errors raised while importing the module itself propagate.
        """
        module = importlib.import_module(module_name)
        obj = module
        try:
            for part in function_or_class_name.split("."):
                if part:
                    obj = getattr(obj, part)
            return obj
        except Exception:
            return None

    def export_setup_func(
        self, setup_func: Callable, timeout: Optional[int] = None
    ) -> bytes:
        """Export the worker process setup hook to GCS and return its key.

        Args:
            setup_func: The setup hook function to pickle and export.
            timeout: Optional timeout forwarded to the GCS KV put.

        Returns:
            The function table key under which the hook was stored.

        Raises:
            Any exception from the GCS KV put is logged and re-raised.
        """
        pickled_function = pickle_dumps(
            setup_func,
            "Cannot serialize the worker_process_setup_hook " f"{setup_func.__name__}",
        )
        function_to_run_id = hashlib.shake_128(pickled_function).digest(
            ray_constants.ID_SIZE
        )
        job_id = self._worker.current_job_id
        key = make_function_table_key(
            # This value should match with gcs_function_manager.h.
            # Otherwise, it won't be GC'ed.
            WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS.encode(),
            # Pass the JobID itself, as the other callers of
            # make_function_table_key do.
            job_id,
            function_to_run_id,
        )
        check_oversized_function(
            pickled_function, setup_func.__name__, "function", self._worker
        )
        try:
            self._worker.gcs_client.internal_kv_put(
                key,
                pickle.dumps(
                    {
                        "job_id": job_id.binary(),
                        "function_id": function_to_run_id,
                        "function": pickled_function,
                    }
                ),
                # overwrite
                True,
                ray_constants.KV_NAMESPACE_FUNCTION_TABLE,
                timeout=timeout,
            )
        except Exception:
            logger.exception(
                "Failed to export the setup hook " f"{setup_func.__name__}."
            )
            raise
        return key

    def export(self, remote_function):
        """Pickle a remote function and export it to the GCS function table.

        Nothing is exported when load_code_from_local is set and the function
        can be resolved from local code, or when an entry for the function
        already exists in the table.

        Args:
            remote_function: the RemoteFunction object.
        """
        function_descriptor = remote_function._function_descriptor
        if self._worker.load_code_from_local:
            # If the function is dynamic (not resolvable locally), we still
            # export it to GCS even if load_code_from_local is set True.
            if (
                self.load_function_or_class_from_local(
                    function_descriptor.module_name,
                    function_descriptor.function_name,
                )
                is not None
            ):
                return
        function = remote_function._function
        pickled_function = remote_function._pickled_function
        check_oversized_function(
            pickled_function,
            remote_function._function_name,
            "remote function",
            self._worker,
        )
        function_id = function_descriptor.function_id.binary()
        key = make_function_table_key(
            b"RemoteFunction",
            self._worker.current_job_id,
            function_id,
        )
        if self._worker.gcs_client.internal_kv_exists(key, KV_NAMESPACE_FUNCTION_TABLE):
            return
        val = pickle.dumps(
            {
                "job_id": self._worker.current_job_id.binary(),
                "function_id": function_id,
                "function_name": remote_function._function_name,
                "module": function.__module__,
                "function": pickled_function,
                "collision_identifier": self.compute_collision_identifier(function),
                "max_calls": remote_function._max_calls,
            }
        )
        self._worker.gcs_client.internal_kv_put(
            key, val, True, KV_NAMESPACE_FUNCTION_TABLE
        )

    def fetch_registered_method(
        self, key: bytes, timeout: Optional[int] = None
    ) -> Optional[ImportedFunctionInfo]:
        """Fetch a remote function entry from the GCS function table.

        Args:
            key: Function table key (as built by make_function_table_key).
            timeout: Optional timeout forwarded to the GCS KV get.

        Returns:
            An ImportedFunctionInfo built from the stored dict (fields absent
            from the stored dict are None), or None if the key is missing.
        """
        vals = self._worker.gcs_client.internal_kv_get(
            key, KV_NAMESPACE_FUNCTION_TABLE, timeout=timeout
        )
        if vals is None:
            return None
        vals = pickle.loads(vals)
        return ImportedFunctionInfo._make(
            vals.get(field) for field in ImportedFunctionInfo._fields
        )

    def fetch_and_register_remote_function(self, key):
        """Fetch a remote function from GCS and register it on this worker.

        If unpickling fails, a placeholder that raises RuntimeError with the
        import traceback is registered instead, so the error surfaces when
        the function is invoked.

        Args:
            key: Function table key of the remote function.

        Returns:
            True if an entry was found and registered, False if the key does
            not exist in GCS.
        """
        remote_function_info = self.fetch_registered_method(key)
        if remote_function_info is None:
            return False
        (
            job_id_str,
            function_id_str,
            function_name,
            serialized_function,
            module,
            max_calls,
        ) = remote_function_info
        function_id = ray.FunctionID(function_id_str)
        job_id = ray.JobID(job_id_str)
        max_calls = int(max_calls)

        # This function is called by ImportThread. Registration needs to be
        # atomic; otherwise another thread could observe the function ID
        # before its execution info is ready.
        with self.lock:
            self._num_task_executions[function_id] = 0

            try:
                function = pickle.loads(serialized_function)
            except Exception:
                # Unpickling failed: record the traceback and register a
                # placeholder that raises it when the function is called.
                traceback_str = format_error_message(traceback.format_exc())

                def f(*args, **kwargs):
                    raise RuntimeError(
                        "The remote function failed to import on the "
                        "worker. This may be because needed library "
                        "dependencies are not installed in the worker "
                        "environment or cannot be found from sys.path "
                        f"{sys.path}:\n\n{traceback_str}"
                    )

                # Use a placeholder method when function pickled failed
                self._function_execution_info[function_id] = FunctionExecutionInfo(
                    function=f, function_name=function_name, max_calls=max_calls
                )

                # Log the error message. Log at DEBUG level to avoid overly
                # spamming the log on import failure. The user gets the error
                # via the RuntimeError message above.
                logger.debug(
                    "Failed to unpickle the remote function "
                    f"'{function_name}' with "
                    f"function ID {function_id.hex()}. "
                    f"Job ID:{job_id}."
                    f"Traceback:\n{traceback_str}. "
                )
            else:
                # The below line is necessary. Because in the driver process,
                # if the function is defined in the file where the python
                # script was started from, its module is `__main__`.
                # However in the worker process, the `__main__` module is a
                # different module, which is `default_worker.py`
                function.__module__ = module
                self._function_execution_info[function_id] = FunctionExecutionInfo(
                    function=function, function_name=function_name, max_calls=max_calls
                )
        return True

    def get_execution_info(self, job_id, function_descriptor):
        """Get the FunctionExecutionInfo of a remote function.

        Args:
            job_id: ID of the job that the function belongs to.
            function_descriptor: The FunctionDescriptor of the function to get.

        Returns:
            A FunctionExecutionInfo object.

        Raises:
            KeyError: If no execution info is registered for the function
                after waiting for it.
        """
        function_id = function_descriptor.function_id
        # If the function has already been loaded, there's no need to load
        # it again.
        if function_id in self._function_execution_info:
            return self._function_execution_info[function_id]

        # With load_code_from_local, try local code first for non-actor
        # functions. If that fails, fall through to GCS even though
        # load_code_from_local is set True.
        if (
            self._worker.load_code_from_local
            and not function_descriptor.is_actor_method()
            and self._load_function_from_local(function_descriptor)
        ):
            return self._function_execution_info[function_id]

        # Load function from GCS: wait until the function to be executed has
        # actually been registered on this worker (the driver function may
        # not be found in sys.path). Warnings are pushed to the user if we
        # spend too long waiting.
        with profiling.profile("wait_for_function"):
            self._wait_for_function(function_descriptor, job_id)

        # `_function_execution_info` is a defaultdict, so plain indexing would
        # silently insert and return `{}` for a missing ID; use `.get` so the
        # descriptive KeyError below is actually raised.
        info = self._function_execution_info.get(function_id)
        if info is None:
            message = (
                "Error occurs in get_execution_info: "
                "job_id: %s, function_descriptor: %s. Message: %s"
                % (job_id, function_descriptor, repr(function_id))
            )
            raise KeyError(message)
        return info

    def _load_function_from_local(self, function_descriptor):
        """Register a non-actor function resolved from local code.

        Returns:
            True if the function was found locally and registered, else False.
        """
        assert not function_descriptor.is_actor_method()
        function_id = function_descriptor.function_id
        function_name = function_descriptor.function_name
        obj = self.load_function_or_class_from_local(
            function_descriptor.module_name, function_name
        )
        if obj is None:
            return False
        # Directly importing from local may break function with dynamic ray.remote,
        # such as the _start_controller function utilized for the Ray service.
        # Unwrap RemoteFunction objects to the underlying Python function.
        function = obj._function if isinstance(obj, RemoteFunction) else obj
        self._function_execution_info[function_id] = FunctionExecutionInfo(
            function=function,
            function_name=function_name,
            max_calls=0,
        )
        self._num_task_executions[function_id] = 0
        return True

    def _wait_for_function(self, function_descriptor, job_id: JobID, timeout=10):
        """Wait until the function to be executed is present on this worker.

        This method will simply loop until the relevant function has been
        registered (fetching it from GCS on each iteration). If we spend too
        long in this loop, that may indicate a problem somewhere and we will
        push an error message to the user once; waiting continues afterwards.

        If this worker is an actor, this returns immediately (after asserting
        the actor is known to the worker); actor loading happens when
        execute_task is called.

        Args:
            function_descriptor: The FunctionDescriptor of the function that
                we want to execute.
            job_id: The ID of the job to push the error message to
                if this times out.
            timeout: Seconds to wait before logging and pushing the warning.
        """
        start_time = time.time()
        # Only send the warning once.
        warning_sent = False
        # The table key does not change between iterations; build it once.
        key = make_function_table_key(
            b"RemoteFunction",
            job_id,
            function_descriptor.function_id.binary(),
        )
        while True:
            with self.lock:
                if not self._worker.actor_id.is_nil():
                    # Actor loading will happen when execute_task is called.
                    assert self._worker.actor_id in self._worker.actors
                    break
                if function_descriptor.function_id in self._function_execution_info:
                    break
                if self.fetch_and_register_remote_function(key):
                    break
            if not warning_sent and time.time() - start_time > timeout:
                warning_message = (
                    "This worker was asked to execute a function "
                    f"that has not been registered ({function_descriptor}, "
                    f"node={self._worker.node_ip_address}, "
                    f"worker_id={self._worker.worker_id.hex()}, "
                    f"pid={os.getpid()}). You may have to restart Ray."
                )
                logger.error(warning_message)
                ray._private.utils.push_error_to_driver(
                    self._worker,
                    ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
                    warning_message,
                    job_id=job_id,
                )
                warning_sent = True
            time.sleep(0.001)

    def export_actor_class(
        self, Class, actor_creation_function_descriptor, actor_method_names
    ):
        """Pickle an actor class and export it to the GCS function table.

        Nothing is exported when load_code_from_local is set and the class
        can be resolved from local code.

        Args:
            Class: The actor class to export.
            actor_creation_function_descriptor: Function descriptor of the
                actor constructor; supplies the module, class name, and the
                function ID used in the table key.
            actor_method_names: Names of the actor's methods, stored as JSON
                so a placeholder class can be built if unpickling fails.
        """
        if self._worker.load_code_from_local:
            # If the class is dynamic (not resolvable locally), we still
            # export it to GCS even if load_code_from_local is set True.
            if (
                self.load_function_or_class_from_local(
                    actor_creation_function_descriptor.module_name,
                    actor_creation_function_descriptor.class_name,
                )
                is not None
            ):
                return

        # `current_job_id` shouldn't be NIL, unless:
        # 1) This worker isn't an actor;
        # 2) And a previous task started a background thread, which didn't
        #    finish before the task finished, and still uses Ray API
        #    after that.
        assert not self._worker.current_job_id.is_nil(), (
            "You might have started a background thread in a non-actor "
            "task, please make sure the thread finishes before the "
            "task finishes."
        )
        job_id = self._worker.current_job_id
        key = make_function_table_key(
            b"ActorClass",
            job_id,
            actor_creation_function_descriptor.function_id.binary(),
        )
        serialized_actor_class = pickle_dumps(
            Class,
            f"Could not serialize the actor class "
            f"{actor_creation_function_descriptor.repr}",
        )
        actor_class_info = {
            "class_name": actor_creation_function_descriptor.class_name.split(".")[-1],
            "module": actor_creation_function_descriptor.module_name,
            "class": serialized_actor_class,
            "job_id": job_id.binary(),
            "collision_identifier": self.compute_collision_identifier(Class),
            "actor_method_names": json.dumps(list(actor_method_names)),
        }

        check_oversized_function(
            actor_class_info["class"],
            actor_class_info["class_name"],
            "actor",
            self._worker,
        )

        self._worker.gcs_client.internal_kv_put(
            key, pickle.dumps(actor_class_info), True, KV_NAMESPACE_FUNCTION_TABLE
        )
        # TODO(rkn): Currently we allow actor classes to be defined
        # within tasks. I tried to disable this, but it may be necessary
        # because of https://github.com/ray-project/ray/issues/1146.

    def load_actor_class(self, job_id, actor_creation_function_descriptor):
        """Load the actor class and register execution info for its methods.

        Args:
            job_id: job ID of the actor.
            actor_creation_function_descriptor: Function descriptor of
                the actor constructor.

        Returns:
            The actor class (cached after the first load).
        """
        function_id = actor_creation_function_descriptor.function_id
        # Check if the actor class already exists in the cache.
        actor_class = self._loaded_actor_classes.get(function_id, None)
        if actor_class is None:
            if self._worker.load_code_from_local:
                # Load actor class from local code first.
                actor_class = self._load_actor_class_from_local(
                    actor_creation_function_descriptor
                )
            # Load from GCS when local loading is disabled, or when the class
            # could not be loaded locally (even if load_code_from_local is
            # set True).
            if actor_class is None:
                actor_class = self._load_actor_class_from_gcs(
                    job_id, actor_creation_function_descriptor
                )

            # Re-inject tracing into the loaded class. This is necessary because
            # cloudpickle doesn't preserve __signature__ attributes on module-level
            # functions. When a class is pickled and unpickled, user-defined methods
            # are looked up from the module, losing the __signature__ that was set by
            # _inject_tracing_into_class during actor creation. Re-injecting tracing
            # ensures the method signatures include _ray_trace_ctx when tracing is
            # enabled, matching the behavior expected by _tracing_actor_method_invocation.
            _inject_tracing_into_class(actor_class)

            # Save the loaded actor class in cache.
            self._loaded_actor_classes[function_id] = actor_class

            # Generate execution info for the methods of this actor class.
            module_name = actor_creation_function_descriptor.module_name
            actor_class_name = actor_creation_function_descriptor.class_name
            actor_methods = inspect.getmembers(
                actor_class, predicate=is_function_or_method
            )
            for actor_method_name, actor_method in actor_methods:
                # Actor creation function descriptor use a unique function
                # hash to solve actor name conflict. When constructing an
                # actor, the actor creation function descriptor will be the
                # key to find __init__ method execution info. So, here we
                # use actor creation function descriptor as method descriptor
                # for generating __init__ method execution info.
                if actor_method_name == "__init__":
                    method_descriptor = actor_creation_function_descriptor
                else:
                    method_descriptor = PythonFunctionDescriptor(
                        module_name, actor_method_name, actor_class_name
                    )
                method_id = method_descriptor.function_id
                executor = self._make_actor_method_executor(
                    actor_method_name, actor_method
                )
                self._function_execution_info[method_id] = FunctionExecutionInfo(
                    function=executor,
                    function_name=actor_method_name,
                    max_calls=0,
                )
                self._num_task_executions[method_id] = 0
            self._num_task_executions[function_id] = 0
        return actor_class

    def _load_actor_class_from_local(self, actor_creation_function_descriptor):
        """Load actor class from local code.

        Returns:
            The class, unwrapped from a ray ActorClass (via
            `__ray_metadata__.modified_class`) when applicable, or None if it
            cannot be resolved locally.
        """
        obj = self.load_function_or_class_from_local(
            actor_creation_function_descriptor.module_name,
            actor_creation_function_descriptor.class_name,
        )
        if isinstance(obj, ray.actor.ActorClass):
            return obj.__ray_metadata__.modified_class
        return obj

    def _create_fake_actor_class(
        self, actor_class_name, actor_method_names, traceback_str
    ):
        """Build a placeholder actor class whose listed methods all raise.

        Used when the real actor class cannot be unpickled, so that calls
        produce an informative RuntimeError containing `traceback_str`.
        """

        class TemporaryActor:
            async def __dummy_method(self):
                """Dummy method for this fake actor class to work for async actors.
                Without this method, this temporary actor class fails to initialize
                if the original actor class was async."""
                pass

        def temporary_actor_method(*args, **kwargs):
            raise RuntimeError(
                f"The actor with name {actor_class_name} "
                "failed to import on the worker. This may be because "
                "needed library dependencies are not installed in the "
                f"worker environment:\n\n{traceback_str}"
            )

        for method in actor_method_names:
            setattr(TemporaryActor, method, temporary_actor_method)

        return TemporaryActor

    def _load_actor_class_from_gcs(self, job_id, actor_creation_function_descriptor):
        """Load actor class from GCS.

        If the stored class cannot be unpickled, a placeholder class whose
        methods raise the unpickling traceback is returned instead.
        """
        key = make_function_table_key(
            b"ActorClass",
            job_id,
            actor_creation_function_descriptor.function_id.binary(),
        )

        # Fetch raw data from GCS.
        vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE)
        fields = ["job_id", "class_name", "module", "class", "actor_method_names"]
        vals = {} if vals is None else pickle.loads(vals)
        (_job_id_str, class_name, module, pickled_class, actor_method_names) = (
            vals.get(field) for field in fields
        )

        class_name = ensure_str(class_name)
        module_name = ensure_str(module)
        actor_method_names = json.loads(ensure_str(actor_method_names))

        actor_class = None
        try:
            with self.lock:
                actor_class = pickle.loads(pickled_class)
        except Exception:
            logger.debug("Failed to load actor class %s.", class_name)
            # Unpickling failed: keep the traceback so the placeholder class
            # can report it when its methods are called.
            traceback_str = format_error_message(traceback.format_exc())
            # The actor class failed to be unpickled, create a fake actor
            # class instead (just to produce error messages and to prevent
            # the driver from hanging).
            actor_class = self._create_fake_actor_class(
                class_name, actor_method_names, traceback_str
            )

        # The below line is necessary. Because in the driver process,
        # if the function is defined in the file where the python script
        # was started from, its module is `__main__`.
        # However in the worker process, the `__main__` module is a
        # different module, which is `default_worker.py`
        actor_class.__module__ = module_name
        return actor_class

    def _make_actor_method_executor(self, method_name: str, method):
        """Make an executor that wraps a user-defined actor method.

        Args:
            method_name: The name of the actor method.
            method: The actor method to wrap. This should be a
                method defined on the actor class and should therefore take an
                instance of the actor as the first argument.

        Returns:
            A function that executes the given actor method on the actor
            instance passed as its first argument. Class methods and static
            methods are called without that instance.
        """

        def actor_method_executor(__ray_actor, *args, **kwargs):
            # Execute the assigned method.
            is_bound = is_class_method(method) or is_static_method(
                type(__ray_actor), method_name
            )
            if is_bound:
                return method(*args, **kwargs)
            else:
                return method(__ray_actor, *args, **kwargs)

        # Set method_name and method as attributes to the executor closure
        # so we can make decision based on these attributes in task executor.
        # Precisely, asyncio support requires to know whether:
        # - the method is a ray internal method: starts with __ray
        # - the method is a coroutine function: defined by async def
        actor_method_executor.name = method_name
        actor_method_executor.method = method

        return actor_method_executor
|