function_manager.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. import dis
  2. import hashlib
  3. import importlib
  4. import inspect
  5. import json
  6. import logging
  7. import os
  8. import sys
  9. import threading
  10. import time
  11. import traceback
  12. from collections import defaultdict, namedtuple
  13. from typing import Callable, Optional
  14. import ray
  15. import ray._private.profiling as profiling
  16. from ray import cloudpickle as pickle
  17. from ray._common.serialization import pickle_dumps
  18. from ray._private import ray_constants
  19. from ray._private.inspect_util import (
  20. is_class_method,
  21. is_function_or_method,
  22. is_static_method,
  23. )
  24. from ray._private.ray_constants import KV_NAMESPACE_FUNCTION_TABLE
  25. from ray._private.utils import (
  26. check_oversized_function,
  27. ensure_str,
  28. format_error_message,
  29. )
  30. from ray._raylet import (
  31. WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS,
  32. JobID,
  33. PythonFunctionDescriptor,
  34. )
  35. from ray.remote_function import RemoteFunction
  36. from ray.util.tracing.tracing_helper import _inject_tracing_into_class
  37. FunctionExecutionInfo = namedtuple(
  38. "FunctionExecutionInfo", ["function", "function_name", "max_calls"]
  39. )
  40. ImportedFunctionInfo = namedtuple(
  41. "ImportedFunctionInfo",
  42. ["job_id", "function_id", "function_name", "function", "module", "max_calls"],
  43. )
  44. """FunctionExecutionInfo: A named tuple storing remote function information."""
  45. logger = logging.getLogger(__name__)
  46. def make_function_table_key(key_type: bytes, job_id: JobID, key: Optional[bytes]):
  47. if key is None:
  48. return b":".join([key_type, job_id.hex().encode()])
  49. else:
  50. return b":".join([key_type, job_id.hex().encode(), key])
  51. class FunctionActorManager:
  52. """A class used to export/load remote functions and actors.
  53. Attributes:
  54. _worker: The worker that this manager is associated with.
  55. _functions_to_export: The remote functions to export when
  56. the worker gets connected.
  57. _actors_to_export: The actors to export when the worker gets
  58. connected.
  59. _function_execution_info: Maps each function_id
  60. to its execution_info.
  61. _num_task_executions: Maps each function_id to the
  62. number of times it has been executed.
  63. imported_actor_classes: The set of actor classes keys (format:
  64. ActorClass:function_id) that are already in GCS.
  65. """
  66. def __init__(self, worker):
  67. self._worker = worker
  68. self._functions_to_export = []
  69. self._actors_to_export = []
  70. # This field is a dictionary that maps function IDs
  71. # to a FunctionExecutionInfo object. This should only be used on
  72. # workers that execute remote functions.
  73. self._function_execution_info = defaultdict(lambda: {})
  74. self._num_task_executions = defaultdict(lambda: {})
  75. # A set of all of the actor class keys that have been imported by the
  76. # import thread. It is safe to convert this worker into an actor of
  77. # these types.
  78. self.imported_actor_classes = set()
  79. self._loaded_actor_classes = {}
  80. # Deserialize an ActorHandle will call load_actor_class(). If a
  81. # function closure captured an ActorHandle, the deserialization of the
  82. # function will be:
  83. # -> fetch_and_register_remote_function (acquire lock)
  84. # -> _load_actor_class_from_gcs (acquire lock, too)
  85. # So, the lock should be a reentrant lock.
  86. self.lock = threading.RLock()
  87. self.execution_infos = {}
  88. # This is the counter to keep track of how many keys have already
  89. # been exported so that we can find next key quicker.
  90. self._num_exported = 0
  91. # This is to protect self._num_exported when doing exporting
  92. self._export_lock = threading.Lock()
  93. def increase_task_counter(self, function_descriptor):
  94. function_id = function_descriptor.function_id
  95. self._num_task_executions[function_id] += 1
  96. def get_task_counter(self, function_descriptor):
  97. function_id = function_descriptor.function_id
  98. return self._num_task_executions[function_id]
  99. def compute_collision_identifier(self, function_or_class):
  100. """The identifier is used to detect excessive duplicate exports.
  101. The identifier is used to determine when the same function or class is
  102. exported many times. This can yield false positives.
  103. Args:
  104. function_or_class: The function or class to compute an identifier
  105. for.
  106. Returns:
  107. The identifier. Note that different functions or classes can give
  108. rise to same identifier. However, the same function should
  109. hopefully always give rise to the same identifier. TODO(rkn):
  110. verify if this is actually the case. Note that if the
  111. identifier is incorrect in any way, then we may give warnings
  112. unnecessarily or fail to give warnings, but the application's
  113. behavior won't change.
  114. """
  115. import io
  116. string_file = io.StringIO()
  117. dis.dis(function_or_class, file=string_file, depth=2)
  118. collision_identifier = function_or_class.__name__ + ":" + string_file.getvalue()
  119. # Return a hash of the identifier in case it is too large.
  120. return hashlib.sha256(collision_identifier.encode("utf-8")).digest()
  121. def load_function_or_class_from_local(self, module_name, function_or_class_name):
  122. """Try to load a function or class in the module from local."""
  123. module = importlib.import_module(module_name)
  124. parts = [part for part in function_or_class_name.split(".") if part]
  125. object = module
  126. try:
  127. for part in parts:
  128. object = getattr(object, part)
  129. return object
  130. except Exception:
  131. return None
  132. def export_setup_func(
  133. self, setup_func: Callable, timeout: Optional[int] = None
  134. ) -> bytes:
  135. """Export the setup hook function and return the key."""
  136. pickled_function = pickle_dumps(
  137. setup_func,
  138. "Cannot serialize the worker_process_setup_hook " f"{setup_func.__name__}",
  139. )
  140. function_to_run_id = hashlib.shake_128(pickled_function).digest(
  141. ray_constants.ID_SIZE
  142. )
  143. key = make_function_table_key(
  144. # This value should match with gcs_function_manager.h.
  145. # Otherwise, it won't be GC'ed.
  146. WORKER_PROCESS_SETUP_HOOK_KEY_NAME_GCS.encode(),
  147. # b"FunctionsToRun",
  148. self._worker.current_job_id.binary(),
  149. function_to_run_id,
  150. )
  151. check_oversized_function(
  152. pickled_function, setup_func.__name__, "function", self._worker
  153. )
  154. try:
  155. self._worker.gcs_client.internal_kv_put(
  156. key,
  157. pickle.dumps(
  158. {
  159. "job_id": self._worker.current_job_id.binary(),
  160. "function_id": function_to_run_id,
  161. "function": pickled_function,
  162. }
  163. ),
  164. # overwrite
  165. True,
  166. ray_constants.KV_NAMESPACE_FUNCTION_TABLE,
  167. timeout=timeout,
  168. )
  169. except Exception as e:
  170. logger.exception(
  171. "Failed to export the setup hook " f"{setup_func.__name__}."
  172. )
  173. raise e
  174. return key
  175. def export(self, remote_function):
  176. """Pickle a remote function and export it to redis.
  177. Args:
  178. remote_function: the RemoteFunction object.
  179. """
  180. if self._worker.load_code_from_local:
  181. function_descriptor = remote_function._function_descriptor
  182. module_name, function_name = (
  183. function_descriptor.module_name,
  184. function_descriptor.function_name,
  185. )
  186. # If the function is dynamic, we still export it to GCS
  187. # even if load_code_from_local is set True.
  188. if (
  189. self.load_function_or_class_from_local(module_name, function_name)
  190. is not None
  191. ):
  192. return
  193. function = remote_function._function
  194. pickled_function = remote_function._pickled_function
  195. check_oversized_function(
  196. pickled_function,
  197. remote_function._function_name,
  198. "remote function",
  199. self._worker,
  200. )
  201. key = make_function_table_key(
  202. b"RemoteFunction",
  203. self._worker.current_job_id,
  204. remote_function._function_descriptor.function_id.binary(),
  205. )
  206. if self._worker.gcs_client.internal_kv_exists(key, KV_NAMESPACE_FUNCTION_TABLE):
  207. return
  208. val = pickle.dumps(
  209. {
  210. "job_id": self._worker.current_job_id.binary(),
  211. "function_id": remote_function._function_descriptor.function_id.binary(), # noqa: E501
  212. "function_name": remote_function._function_name,
  213. "module": function.__module__,
  214. "function": pickled_function,
  215. "collision_identifier": self.compute_collision_identifier(function),
  216. "max_calls": remote_function._max_calls,
  217. }
  218. )
  219. self._worker.gcs_client.internal_kv_put(
  220. key, val, True, KV_NAMESPACE_FUNCTION_TABLE
  221. )
  222. def fetch_registered_method(
  223. self, key: str, timeout: Optional[int] = None
  224. ) -> Optional[ImportedFunctionInfo]:
  225. vals = self._worker.gcs_client.internal_kv_get(
  226. key, KV_NAMESPACE_FUNCTION_TABLE, timeout=timeout
  227. )
  228. if vals is None:
  229. return None
  230. else:
  231. vals = pickle.loads(vals)
  232. fields = [
  233. "job_id",
  234. "function_id",
  235. "function_name",
  236. "function",
  237. "module",
  238. "max_calls",
  239. ]
  240. return ImportedFunctionInfo._make(vals.get(field) for field in fields)
  241. def fetch_and_register_remote_function(self, key):
  242. """Import a remote function."""
  243. remote_function_info = self.fetch_registered_method(key)
  244. if not remote_function_info:
  245. return False
  246. (
  247. job_id_str,
  248. function_id_str,
  249. function_name,
  250. serialized_function,
  251. module,
  252. max_calls,
  253. ) = remote_function_info
  254. function_id = ray.FunctionID(function_id_str)
  255. job_id = ray.JobID(job_id_str)
  256. max_calls = int(max_calls)
  257. # This function is called by ImportThread. This operation needs to be
  258. # atomic. Otherwise, there is race condition. Another thread may use
  259. # the temporary function above before the real function is ready.
  260. with self.lock:
  261. self._num_task_executions[function_id] = 0
  262. try:
  263. function = pickle.loads(serialized_function)
  264. except Exception:
  265. # If an exception was thrown when the remote function was
  266. # imported, we record the traceback and notify the scheduler
  267. # of the failure.
  268. traceback_str = format_error_message(traceback.format_exc())
  269. def f(*args, **kwargs):
  270. raise RuntimeError(
  271. "The remote function failed to import on the "
  272. "worker. This may be because needed library "
  273. "dependencies are not installed in the worker "
  274. "environment or cannot be found from sys.path "
  275. f"{sys.path}:\n\n{traceback_str}"
  276. )
  277. # Use a placeholder method when function pickled failed
  278. self._function_execution_info[function_id] = FunctionExecutionInfo(
  279. function=f, function_name=function_name, max_calls=max_calls
  280. )
  281. # Log the error message. Log at DEBUG level to avoid overly
  282. # spamming the log on import failure. The user gets the error
  283. # via the RuntimeError message above.
  284. logger.debug(
  285. "Failed to unpickle the remote function "
  286. f"'{function_name}' with "
  287. f"function ID {function_id.hex()}. "
  288. f"Job ID:{job_id}."
  289. f"Traceback:\n{traceback_str}. "
  290. )
  291. else:
  292. # The below line is necessary. Because in the driver process,
  293. # if the function is defined in the file where the python
  294. # script was started from, its module is `__main__`.
  295. # However in the worker process, the `__main__` module is a
  296. # different module, which is `default_worker.py`
  297. function.__module__ = module
  298. self._function_execution_info[function_id] = FunctionExecutionInfo(
  299. function=function, function_name=function_name, max_calls=max_calls
  300. )
  301. return True
  302. def get_execution_info(self, job_id, function_descriptor):
  303. """Get the FunctionExecutionInfo of a remote function.
  304. Args:
  305. job_id: ID of the job that the function belongs to.
  306. function_descriptor: The FunctionDescriptor of the function to get.
  307. Returns:
  308. A FunctionExecutionInfo object.
  309. """
  310. function_id = function_descriptor.function_id
  311. # If the function has already been loaded,
  312. # There's no need to load again
  313. if function_id in self._function_execution_info:
  314. return self._function_execution_info[function_id]
  315. if self._worker.load_code_from_local:
  316. # Load function from local code.
  317. if not function_descriptor.is_actor_method():
  318. # If the function is not able to be loaded,
  319. # try to load it from GCS,
  320. # even if load_code_from_local is set True
  321. if self._load_function_from_local(function_descriptor) is True:
  322. return self._function_execution_info[function_id]
  323. # Load function from GCS.
  324. # Wait until the function to be executed has actually been
  325. # registered on this worker. We will push warnings to the user if
  326. # we spend too long in this loop.
  327. # The driver function may not be found in sys.path. Try to load
  328. # the function from GCS.
  329. with profiling.profile("wait_for_function"):
  330. self._wait_for_function(function_descriptor, job_id)
  331. try:
  332. function_id = function_descriptor.function_id
  333. info = self._function_execution_info[function_id]
  334. except KeyError as e:
  335. message = (
  336. "Error occurs in get_execution_info: "
  337. "job_id: %s, function_descriptor: %s. Message: %s"
  338. % (job_id, function_descriptor, e)
  339. )
  340. raise KeyError(message)
  341. return info
  342. def _load_function_from_local(self, function_descriptor):
  343. assert not function_descriptor.is_actor_method()
  344. function_id = function_descriptor.function_id
  345. module_name, function_name = (
  346. function_descriptor.module_name,
  347. function_descriptor.function_name,
  348. )
  349. object = self.load_function_or_class_from_local(module_name, function_name)
  350. if object is not None:
  351. # Directly importing from local may break function with dynamic ray.remote,
  352. # such as the _start_controller function utilized for the Ray service.
  353. if isinstance(object, RemoteFunction):
  354. function = object._function
  355. else:
  356. function = object
  357. self._function_execution_info[function_id] = FunctionExecutionInfo(
  358. function=function,
  359. function_name=function_name,
  360. max_calls=0,
  361. )
  362. self._num_task_executions[function_id] = 0
  363. return True
  364. else:
  365. return False
  366. def _wait_for_function(self, function_descriptor, job_id: str, timeout=10):
  367. """Wait until the function to be executed is present on this worker.
  368. This method will simply loop until the import thread has imported the
  369. relevant function. If we spend too long in this loop, that may indicate
  370. a problem somewhere and we will push an error message to the user.
  371. If this worker is an actor, then this will wait until the actor has
  372. been defined.
  373. Args:
  374. function_descriptor : The FunctionDescriptor of the function that
  375. we want to execute.
  376. job_id: The ID of the job to push the error message to
  377. if this times out.
  378. """
  379. start_time = time.time()
  380. # Only send the warning once.
  381. warning_sent = False
  382. while True:
  383. with self.lock:
  384. if self._worker.actor_id.is_nil():
  385. if function_descriptor.function_id in self._function_execution_info:
  386. break
  387. else:
  388. key = make_function_table_key(
  389. b"RemoteFunction",
  390. job_id,
  391. function_descriptor.function_id.binary(),
  392. )
  393. if self.fetch_and_register_remote_function(key) is True:
  394. break
  395. else:
  396. assert not self._worker.actor_id.is_nil()
  397. # Actor loading will happen when execute_task is called.
  398. assert self._worker.actor_id in self._worker.actors
  399. break
  400. if time.time() - start_time > timeout:
  401. warning_message = (
  402. "This worker was asked to execute a function "
  403. f"that has not been registered ({function_descriptor}, "
  404. f"node={self._worker.node_ip_address}, "
  405. f"worker_id={self._worker.worker_id.hex()}, "
  406. f"pid={os.getpid()}). You may have to restart Ray."
  407. )
  408. if not warning_sent:
  409. logger.error(warning_message)
  410. ray._private.utils.push_error_to_driver(
  411. self._worker,
  412. ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
  413. warning_message,
  414. job_id=job_id,
  415. )
  416. warning_sent = True
  417. time.sleep(0.001)
  418. def export_actor_class(
  419. self, Class, actor_creation_function_descriptor, actor_method_names
  420. ):
  421. if self._worker.load_code_from_local:
  422. module_name, class_name = (
  423. actor_creation_function_descriptor.module_name,
  424. actor_creation_function_descriptor.class_name,
  425. )
  426. # If the class is dynamic, we still export it to GCS
  427. # even if load_code_from_local is set True.
  428. if (
  429. self.load_function_or_class_from_local(module_name, class_name)
  430. is not None
  431. ):
  432. return
  433. # `current_job_id` shouldn't be NIL, unless:
  434. # 1) This worker isn't an actor;
  435. # 2) And a previous task started a background thread, which didn't
  436. # finish before the task finished, and still uses Ray API
  437. # after that.
  438. assert not self._worker.current_job_id.is_nil(), (
  439. "You might have started a background thread in a non-actor "
  440. "task, please make sure the thread finishes before the "
  441. "task finishes."
  442. )
  443. job_id = self._worker.current_job_id
  444. key = make_function_table_key(
  445. b"ActorClass",
  446. job_id,
  447. actor_creation_function_descriptor.function_id.binary(),
  448. )
  449. serialized_actor_class = pickle_dumps(
  450. Class,
  451. f"Could not serialize the actor class "
  452. f"{actor_creation_function_descriptor.repr}",
  453. )
  454. actor_class_info = {
  455. "class_name": actor_creation_function_descriptor.class_name.split(".")[-1],
  456. "module": actor_creation_function_descriptor.module_name,
  457. "class": serialized_actor_class,
  458. "job_id": job_id.binary(),
  459. "collision_identifier": self.compute_collision_identifier(Class),
  460. "actor_method_names": json.dumps(list(actor_method_names)),
  461. }
  462. check_oversized_function(
  463. actor_class_info["class"],
  464. actor_class_info["class_name"],
  465. "actor",
  466. self._worker,
  467. )
  468. self._worker.gcs_client.internal_kv_put(
  469. key, pickle.dumps(actor_class_info), True, KV_NAMESPACE_FUNCTION_TABLE
  470. )
  471. # TODO(rkn): Currently we allow actor classes to be defined
  472. # within tasks. I tried to disable this, but it may be necessary
  473. # because of https://github.com/ray-project/ray/issues/1146.
  474. def load_actor_class(self, job_id, actor_creation_function_descriptor):
  475. """Load the actor class.
  476. Args:
  477. job_id: job ID of the actor.
  478. actor_creation_function_descriptor: Function descriptor of
  479. the actor constructor.
  480. Returns:
  481. The actor class.
  482. """
  483. function_id = actor_creation_function_descriptor.function_id
  484. # Check if the actor class already exists in the cache.
  485. actor_class = self._loaded_actor_classes.get(function_id, None)
  486. if actor_class is None:
  487. # Load actor class.
  488. if self._worker.load_code_from_local:
  489. # Load actor class from local code first.
  490. actor_class = self._load_actor_class_from_local(
  491. actor_creation_function_descriptor
  492. )
  493. # If the actor is unable to be loaded
  494. # from local, try to load it
  495. # from GCS even if load_code_from_local is set True
  496. if actor_class is None:
  497. actor_class = self._load_actor_class_from_gcs(
  498. job_id, actor_creation_function_descriptor
  499. )
  500. else:
  501. # Load actor class from GCS.
  502. actor_class = self._load_actor_class_from_gcs(
  503. job_id, actor_creation_function_descriptor
  504. )
  505. # Re-inject tracing into the loaded class. This is necessary because
  506. # cloudpickle doesn't preserve __signature__ attributes on module-level
  507. # functions. When a class is pickled and unpickled, user-defined methods
  508. # are looked up from the module, losing the __signature__ that was set by
  509. # _inject_tracing_into_class during actor creation. Re-injecting tracing
  510. # ensures the method signatures include _ray_trace_ctx when tracing is
  511. # enabled, matching the behavior expected by _tracing_actor_method_invocation.
  512. _inject_tracing_into_class(actor_class)
  513. # Save the loaded actor class in cache.
  514. self._loaded_actor_classes[function_id] = actor_class
  515. # Generate execution info for the methods of this actor class.
  516. module_name = actor_creation_function_descriptor.module_name
  517. actor_class_name = actor_creation_function_descriptor.class_name
  518. actor_methods = inspect.getmembers(
  519. actor_class, predicate=is_function_or_method
  520. )
  521. for actor_method_name, actor_method in actor_methods:
  522. # Actor creation function descriptor use a unique function
  523. # hash to solve actor name conflict. When constructing an
  524. # actor, the actor creation function descriptor will be the
  525. # key to find __init__ method execution info. So, here we
  526. # use actor creation function descriptor as method descriptor
  527. # for generating __init__ method execution info.
  528. if actor_method_name == "__init__":
  529. method_descriptor = actor_creation_function_descriptor
  530. else:
  531. method_descriptor = PythonFunctionDescriptor(
  532. module_name, actor_method_name, actor_class_name
  533. )
  534. method_id = method_descriptor.function_id
  535. executor = self._make_actor_method_executor(
  536. actor_method_name, actor_method
  537. )
  538. self._function_execution_info[method_id] = FunctionExecutionInfo(
  539. function=executor,
  540. function_name=actor_method_name,
  541. max_calls=0,
  542. )
  543. self._num_task_executions[method_id] = 0
  544. self._num_task_executions[function_id] = 0
  545. return actor_class
  546. def _load_actor_class_from_local(self, actor_creation_function_descriptor):
  547. """Load actor class from local code."""
  548. module_name, class_name = (
  549. actor_creation_function_descriptor.module_name,
  550. actor_creation_function_descriptor.class_name,
  551. )
  552. object = self.load_function_or_class_from_local(module_name, class_name)
  553. if object is not None:
  554. if isinstance(object, ray.actor.ActorClass):
  555. return object.__ray_metadata__.modified_class
  556. else:
  557. return object
  558. else:
  559. return None
  560. def _create_fake_actor_class(
  561. self, actor_class_name, actor_method_names, traceback_str
  562. ):
  563. class TemporaryActor:
  564. async def __dummy_method(self):
  565. """Dummy method for this fake actor class to work for async actors.
  566. Without this method, this temporary actor class fails to initialize
  567. if the original actor class was async."""
  568. pass
  569. def temporary_actor_method(*args, **kwargs):
  570. raise RuntimeError(
  571. f"The actor with name {actor_class_name} "
  572. "failed to import on the worker. This may be because "
  573. "needed library dependencies are not installed in the "
  574. f"worker environment:\n\n{traceback_str}"
  575. )
  576. for method in actor_method_names:
  577. setattr(TemporaryActor, method, temporary_actor_method)
  578. return TemporaryActor
  579. def _load_actor_class_from_gcs(self, job_id, actor_creation_function_descriptor):
  580. """Load actor class from GCS."""
  581. key = make_function_table_key(
  582. b"ActorClass",
  583. job_id,
  584. actor_creation_function_descriptor.function_id.binary(),
  585. )
  586. # Fetch raw data from GCS.
  587. vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE)
  588. fields = ["job_id", "class_name", "module", "class", "actor_method_names"]
  589. if vals is None:
  590. vals = {}
  591. else:
  592. vals = pickle.loads(vals)
  593. (job_id_str, class_name, module, pickled_class, actor_method_names) = (
  594. vals.get(field) for field in fields
  595. )
  596. class_name = ensure_str(class_name)
  597. module_name = ensure_str(module)
  598. job_id = ray.JobID(job_id_str)
  599. actor_method_names = json.loads(ensure_str(actor_method_names))
  600. actor_class = None
  601. try:
  602. with self.lock:
  603. actor_class = pickle.loads(pickled_class)
  604. except Exception:
  605. logger.debug("Failed to load actor class %s.", class_name)
  606. # If an exception was thrown when the actor was imported, we record
  607. # the traceback and notify the scheduler of the failure.
  608. traceback_str = format_error_message(traceback.format_exc())
  609. # The actor class failed to be unpickled, create a fake actor
  610. # class instead (just to produce error messages and to prevent
  611. # the driver from hanging).
  612. actor_class = self._create_fake_actor_class(
  613. class_name, actor_method_names, traceback_str
  614. )
  615. # The below line is necessary. Because in the driver process,
  616. # if the function is defined in the file where the python script
  617. # was started from, its module is `__main__`.
  618. # However in the worker process, the `__main__` module is a
  619. # different module, which is `default_worker.py`
  620. actor_class.__module__ = module_name
  621. return actor_class
  622. def _make_actor_method_executor(self, method_name: str, method):
  623. """Make an executor that wraps a user-defined actor method.
  624. The wrapped method updates the worker's internal state and performs any
  625. necessary checkpointing operations.
  626. Args:
  627. method_name: The name of the actor method.
  628. method: The actor method to wrap. This should be a
  629. method defined on the actor class and should therefore take an
  630. instance of the actor as the first argument.
  631. Returns:
  632. A function that executes the given actor method on the worker's
  633. stored instance of the actor. The function also updates the
  634. worker's internal state to record the executed method.
  635. """
  636. def actor_method_executor(__ray_actor, *args, **kwargs):
  637. # Execute the assigned method.
  638. is_bound = is_class_method(method) or is_static_method(
  639. type(__ray_actor), method_name
  640. )
  641. if is_bound:
  642. return method(*args, **kwargs)
  643. else:
  644. return method(__ray_actor, *args, **kwargs)
  645. # Set method_name and method as attributes to the executor closure
  646. # so we can make decision based on these attributes in task executor.
  647. # Precisely, asyncio support requires to know whether:
  648. # - the method is a ray internal method: starts with __ray
  649. # - the method is a coroutine function: defined by async def
  650. actor_method_executor.name = method_name
  651. actor_method_executor.method = method
  652. return actor_method_executor