serialization.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772
  1. import logging
  2. import threading
  3. import traceback
  4. import warnings
  5. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
  6. if TYPE_CHECKING:
  7. import torch
  8. import google.protobuf.message
  9. import ray._private.utils
  10. import ray.cloudpickle as pickle
  11. import ray.exceptions
  12. from ray._private import (
  13. ray_constants,
  14. tensor_serialization_utils,
  15. )
  16. from ray._raylet import (
  17. DynamicObjectRefGenerator,
  18. MessagePackSerializedObject,
  19. MessagePackSerializer,
  20. Pickle5SerializedObject,
  21. Pickle5Writer,
  22. RawSerializedObject,
  23. SerializedRayObject,
  24. split_buffer,
  25. unpack_pickle5_buffers,
  26. )
  27. from ray.core.generated.common_pb2 import ErrorType, RayErrorInfo
  28. from ray.exceptions import (
  29. ActorDiedError,
  30. ActorPlacementGroupRemoved,
  31. ActorUnavailableError,
  32. ActorUnschedulableError,
  33. LocalRayletDiedError,
  34. NodeDiedError,
  35. ObjectFetchTimedOutError,
  36. ObjectFreedError,
  37. ObjectLostError,
  38. ObjectReconstructionFailedError,
  39. ObjectRefStreamEndOfStreamError,
  40. OutOfDiskError,
  41. OutOfMemoryError,
  42. OwnerDiedError,
  43. PlasmaObjectNotAvailable,
  44. RayError,
  45. RaySystemError,
  46. RayTaskError,
  47. ReferenceCountingAssertionError,
  48. RuntimeEnvSetupError,
  49. TaskCancelledError,
  50. TaskPlacementGroupRemoved,
  51. TaskUnschedulableError,
  52. WorkerCrashedError,
  53. )
  54. from ray.experimental.compiled_dag_ref import CompiledDAGRef
  55. from ray.util import serialization_addons
  56. logger = logging.getLogger(__name__)
  57. ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION = ray_constants.env_bool(
  58. "RAY_allow_out_of_band_object_ref_serialization", True
  59. )
  60. class DeserializationError(Exception):
  61. pass
  62. def _object_ref_deserializer(
  63. binary, call_site, owner_address, object_status, tensor_transport
  64. ):
  65. # NOTE(suquark): This function should be a global function so
  66. # cloudpickle can access it directly. Otherwise cloudpickle
  67. # has to dump the whole function definition, which is inefficient.
  68. # NOTE(swang): Must deserialize the object first before asking
  69. # the core worker to resolve the value. This is to make sure
  70. # that the ref count for the ObjectRef is greater than 0 by the
  71. # time the core worker resolves the value of the object.
  72. obj_ref = ray.ObjectRef(
  73. binary, owner_address, call_site, tensor_transport=tensor_transport
  74. )
  75. # TODO(edoakes): we should be able to just capture a reference
  76. # to 'self' here instead, but this function is itself pickled
  77. # somewhere, which causes an error.
  78. if owner_address:
  79. worker = ray._private.worker.global_worker
  80. worker.check_connected()
  81. context = worker.get_serialization_context()
  82. outer_id = context.get_outer_object_ref()
  83. # outer_id is None in the case that this ObjectRef was closed
  84. # over in a function or pickled directly using pickle.dumps().
  85. if outer_id is None:
  86. outer_id = ray.ObjectRef.nil()
  87. worker.core_worker.deserialize_and_register_object_ref(
  88. obj_ref.binary(), outer_id, owner_address, object_status
  89. )
  90. return obj_ref
  91. def _gpu_object_ref_deserializer(
  92. binary,
  93. call_site,
  94. owner_address,
  95. object_status,
  96. tensor_transport,
  97. gpu_object_meta,
  98. ):
  99. """
  100. Deserialize a GPU object ref. When the GPU object ref is deserialized,
  101. it firstly deserialize the normal object ref, and then add metadata of
  102. the GPU object to the GPU object manager, which will be used to fetch
  103. the GPU object later.
  104. Args:
  105. binary: The binary data of the object ref.
  106. call_site: The call site of the object ref.
  107. owner_address: The owner address of the object ref.
  108. object_status: The object status of the object ref.
  109. tensor_transport: The tensor transport value of the GPU object ref.
  110. gpu_object_meta: The GPU object metadata. This is used to fetch the GPU object later.
  111. Returns:
  112. The deserialized GPU object ref.
  113. """
  114. obj_ref = _object_ref_deserializer(
  115. binary, call_site, owner_address, object_status, tensor_transport
  116. )
  117. gpu_object_manager = ray._private.worker.global_worker.gpu_object_manager
  118. gpu_object_manager.set_gpu_object_metadata(obj_ref.hex(), gpu_object_meta)
  119. return obj_ref
  120. def _actor_handle_deserializer(serialized_obj, weak_ref):
  121. # If this actor handle was stored in another object, then tell the
  122. # core worker.
  123. context = ray._private.worker.global_worker.get_serialization_context()
  124. outer_id = context.get_outer_object_ref()
  125. return ray.actor.ActorHandle._deserialization_helper(
  126. serialized_obj, weak_ref, outer_id
  127. )
  128. class SerializationContext:
  129. """Initialize the serialization library.
  130. This defines a custom serializer for object refs and also tells ray to
  131. serialize several exception classes that we define for error handling.
  132. """
  133. def __init__(self, worker):
  134. self.worker = worker
  135. self._thread_local = threading.local()
  136. # This flag is to mark whether the custom serializer for torch.Tensor has
  137. # been registered. If the method is decorated with
  138. # `@ray.method(tensor_transport="xxx")`, it will use external transport
  139. # (e.g. gloo, nccl, etc.) for tensor communication between actors,
  140. # instead of the normal serialize -> object store -> deserialize codepath.
  141. self._torch_custom_serializer_registered = False
  142. # Enable zero-copy serialization of tensors if the environment variable is set.
  143. self._zero_copy_tensors_enabled = (
  144. ray_constants.RAY_ENABLE_ZERO_COPY_TORCH_TENSORS
  145. )
  146. if self._zero_copy_tensors_enabled:
  147. try:
  148. import torch
  149. self._register_cloudpickle_reducer(
  150. torch.Tensor, tensor_serialization_utils.zero_copy_tensors_reducer
  151. )
  152. except ImportError:
  153. # Warn and disable zero-copy tensor serialization when PyTorch is missing,
  154. # even if RAY_ENABLE_ZERO_COPY_TORCH_TENSORS is set.
  155. warnings.warn(
  156. "PyTorch is not installed. Disabling zero-copy tensor serialization "
  157. "even though RAY_ENABLE_ZERO_COPY_TORCH_TENSORS is set.",
  158. tensor_serialization_utils.ZeroCopyTensorsWarning,
  159. stacklevel=3,
  160. )
  161. self._zero_copy_tensors_enabled = False
  162. def actor_handle_reducer(obj):
  163. ray._private.worker.global_worker.check_connected()
  164. serialized, actor_handle_id, weak_ref = obj._serialization_helper()
  165. # Update ref counting for the actor handle
  166. if not weak_ref:
  167. self.add_contained_object_ref(
  168. actor_handle_id,
  169. # Right now, so many tests are failing when this is set.
  170. # Allow it for now, but we should eventually disallow it here.
  171. allow_out_of_band_serialization=True,
  172. )
  173. return _actor_handle_deserializer, (serialized, weak_ref)
  174. self._register_cloudpickle_reducer(ray.actor.ActorHandle, actor_handle_reducer)
  175. def compiled_dag_ref_reducer(obj):
  176. raise TypeError("Serialization of CompiledDAGRef is not supported.")
  177. self._register_cloudpickle_reducer(CompiledDAGRef, compiled_dag_ref_reducer)
  178. def object_ref_reducer(obj):
  179. worker = ray._private.worker.global_worker
  180. worker.check_connected()
  181. self.add_contained_object_ref(
  182. obj,
  183. allow_out_of_band_serialization=(
  184. ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION
  185. ),
  186. call_site=obj.call_site(),
  187. )
  188. obj, owner_address, object_status = worker.core_worker.serialize_object_ref(
  189. obj
  190. )
  191. # Check if this is a GPU ObjectRef being serialized inside a collection
  192. if (
  193. self.is_in_band_serialization()
  194. and worker.gpu_object_manager.is_managed_object(obj.hex())
  195. ):
  196. gpu_object_manager = (
  197. ray._private.worker.global_worker.gpu_object_manager
  198. )
  199. gpu_object_meta = gpu_object_manager.get_gpu_object_metadata(obj.hex())
  200. if gpu_object_meta.tensor_transport_meta is None:
  201. raise NotImplementedError(
  202. f"Tensor transport metadata is not available for object id: {obj.hex()} at the time of borrowing. "
  203. "This is likely because the object you're trying to borrow an object that was not created on the "
  204. "owner (not through ray.put). This is not supported yet, see issue #59644 for more details."
  205. )
  206. return _gpu_object_ref_deserializer, (
  207. obj.binary(),
  208. obj.call_site(),
  209. owner_address,
  210. object_status,
  211. obj.tensor_transport(),
  212. gpu_object_meta,
  213. )
  214. return _object_ref_deserializer, (
  215. obj.binary(),
  216. obj.call_site(),
  217. owner_address,
  218. object_status,
  219. obj.tensor_transport(),
  220. )
  221. self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
  222. def object_ref_generator_reducer(obj):
  223. return DynamicObjectRefGenerator, (obj._refs,)
  224. self._register_cloudpickle_reducer(
  225. DynamicObjectRefGenerator, object_ref_generator_reducer
  226. )
  227. serialization_addons.apply(self)
  228. def _register_cloudpickle_reducer(self, cls, reducer):
  229. pickle.CloudPickler.dispatch[cls] = reducer
  230. def _unregister_cloudpickle_reducer(self, cls):
  231. pickle.CloudPickler.dispatch.pop(cls, None)
  232. def _register_cloudpickle_serializer(
  233. self, cls, custom_serializer, custom_deserializer
  234. ):
  235. def _CloudPicklerReducer(obj):
  236. return custom_deserializer, (custom_serializer(obj),)
  237. # construct a reducer
  238. pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
  239. def is_in_band_serialization(self):
  240. return getattr(self._thread_local, "in_band", False)
  241. def set_in_band_serialization(self):
  242. self._thread_local.in_band = True
  243. def set_out_of_band_serialization(self):
  244. self._thread_local.in_band = False
  245. def get_outer_object_ref(self):
  246. stack = getattr(self._thread_local, "object_ref_stack", [])
  247. return stack[-1] if stack else None
  248. def get_and_clear_contained_object_refs(self):
  249. if not hasattr(self._thread_local, "object_refs"):
  250. self._thread_local.object_refs = set()
  251. return set()
  252. object_refs = self._thread_local.object_refs
  253. self._thread_local.object_refs = set()
  254. return object_refs
  255. def add_contained_object_ref(
  256. self,
  257. object_ref: "ray.ObjectRef",
  258. *,
  259. allow_out_of_band_serialization: bool,
  260. call_site: Optional[str] = None,
  261. ):
  262. if self.is_in_band_serialization():
  263. # This object ref is being stored in an object. Add the ID to the
  264. # list of IDs contained in the object so that we keep the inner
  265. # object value alive as long as the outer object is in scope.
  266. if not hasattr(self._thread_local, "object_refs"):
  267. self._thread_local.object_refs = set()
  268. self._thread_local.object_refs.add(object_ref)
  269. else:
  270. if not allow_out_of_band_serialization:
  271. raise ray.exceptions.OufOfBandObjectRefSerializationException(
  272. f"It is not allowed to serialize ray.ObjectRef {object_ref.hex()}. "
  273. "If you want to allow serialization, "
  274. "set `RAY_allow_out_of_band_object_ref_serialization=1.` "
  275. "If you set the env var, the object is pinned forever in the "
  276. "lifetime of the worker process and can cause Ray object leaks. "
  277. "See the callsite and trace to find where the serialization "
  278. "occurs.\nCallsite: "
  279. f"{call_site or 'Disabled. Set RAY_record_ref_creation_sites=1'}"
  280. )
  281. else:
  282. # If this serialization is out-of-band (e.g., from a call to
  283. # cloudpickle directly or captured in a remote function/actor),
  284. # then pin the object for the lifetime of this worker by adding
  285. # a local reference that won't ever be removed.
  286. ray._private.worker.global_worker.core_worker.add_object_ref_reference(
  287. object_ref
  288. )
  289. def _deserialize_pickle5_data(
  290. self,
  291. data: Any,
  292. out_of_band_tensors: Optional[List["torch.Tensor"]],
  293. ) -> Any:
  294. """
  295. Args:
  296. data: The data to deserialize.
  297. out_of_band_tensors: Tensors that were sent out-of-band. If this is
  298. not None, then the serialized data will contain placeholders
  299. that need to be replaced with these tensors.
  300. Returns:
  301. Any: The deserialized object.
  302. """
  303. from ray.experimental.channel import ChannelContext
  304. ctx = ChannelContext.get_current().serialization_context
  305. enable_gpu_objects = out_of_band_tensors is not None
  306. if enable_gpu_objects:
  307. ctx.reset_out_of_band_tensors(out_of_band_tensors)
  308. try:
  309. in_band, buffers = unpack_pickle5_buffers(data)
  310. if len(buffers) > 0:
  311. obj = pickle.loads(in_band, buffers=buffers)
  312. else:
  313. obj = pickle.loads(in_band)
  314. # cloudpickle does not provide error types
  315. except pickle.pickle.PicklingError:
  316. raise DeserializationError()
  317. finally:
  318. if enable_gpu_objects:
  319. ctx.reset_out_of_band_tensors([])
  320. return obj
  321. def _deserialize_msgpack_data(
  322. self,
  323. data,
  324. metadata_fields,
  325. out_of_band_tensors: Optional[List["torch.Tensor"]] = None,
  326. ):
  327. msgpack_data, pickle5_data = split_buffer(data)
  328. if metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_PYTHON:
  329. python_objects = self._deserialize_pickle5_data(
  330. pickle5_data, out_of_band_tensors
  331. )
  332. else:
  333. python_objects = []
  334. try:
  335. def _python_deserializer(index):
  336. return python_objects[index]
  337. obj = MessagePackSerializer.loads(msgpack_data, _python_deserializer)
  338. except Exception:
  339. raise DeserializationError()
  340. return obj
  341. def _deserialize_error_info(self, data, metadata_fields):
  342. assert data
  343. pb_bytes = self._deserialize_msgpack_data(data, metadata_fields)
  344. assert pb_bytes
  345. ray_error_info = RayErrorInfo()
  346. ray_error_info.ParseFromString(pb_bytes)
  347. return ray_error_info
  348. def _deserialize_actor_died_error(self, data, metadata_fields):
  349. if not data:
  350. return ActorDiedError()
  351. ray_error_info = self._deserialize_error_info(data, metadata_fields)
  352. assert ray_error_info.HasField("actor_died_error")
  353. if ray_error_info.actor_died_error.HasField("creation_task_failure_context"):
  354. return RayError.from_ray_exception(
  355. ray_error_info.actor_died_error.creation_task_failure_context
  356. )
  357. else:
  358. assert ray_error_info.actor_died_error.HasField("actor_died_error_context")
  359. return ActorDiedError(
  360. cause=ray_error_info.actor_died_error.actor_died_error_context
  361. )
  362. def _deserialize_object(
  363. self,
  364. data,
  365. metadata,
  366. object_ref,
  367. out_of_band_tensors: Optional[List["torch.Tensor"]],
  368. ):
  369. if metadata:
  370. metadata_fields = metadata.split(b",")
  371. if metadata_fields[0] in [
  372. ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
  373. ray_constants.OBJECT_METADATA_TYPE_PYTHON,
  374. ]:
  375. return self._deserialize_msgpack_data(
  376. data, metadata_fields, out_of_band_tensors
  377. )
  378. # Check if the object should be returned as raw bytes.
  379. if metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_RAW:
  380. if data is None:
  381. return b""
  382. return data.to_pybytes()
  383. elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
  384. obj = self._deserialize_msgpack_data(
  385. data, metadata_fields, out_of_band_tensors
  386. )
  387. # The last character is a 1 if weak_ref=True and 0 else.
  388. serialized, weak_ref = obj[:-1], obj[-1:] == b"1"
  389. return _actor_handle_deserializer(serialized, weak_ref)
  390. # Otherwise, return an exception object based on
  391. # the error type.
  392. try:
  393. error_type = int(metadata_fields[0])
  394. except Exception:
  395. raise Exception(
  396. f"Can't deserialize object: {object_ref}, " f"metadata: {metadata}"
  397. )
  398. # RayTaskError is serialized with pickle5 in the data field.
  399. # TODO (kfstorm): exception serialization should be language
  400. # independent.
  401. if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
  402. obj = self._deserialize_msgpack_data(
  403. data, metadata_fields, out_of_band_tensors
  404. )
  405. return RayError.from_bytes(obj)
  406. elif error_type == ErrorType.Value("WORKER_DIED"):
  407. return WorkerCrashedError()
  408. elif error_type == ErrorType.Value("ACTOR_DIED"):
  409. return self._deserialize_actor_died_error(data, metadata_fields)
  410. elif error_type == ErrorType.Value("LOCAL_RAYLET_DIED"):
  411. return LocalRayletDiedError()
  412. elif error_type == ErrorType.Value("TASK_CANCELLED"):
  413. # Task cancellations are serialized in two ways, so check both
  414. # deserialization paths.
  415. # TODO(swang): We should only have one serialization path.
  416. try:
  417. # Deserialization from C++ (the CoreWorker task submitter).
  418. # The error info will be stored as a RayErrorInfo.
  419. error_message = ""
  420. if data:
  421. error_info = self._deserialize_error_info(data, metadata_fields)
  422. error_message = error_info.error_message
  423. return TaskCancelledError(error_message=error_message)
  424. except google.protobuf.message.DecodeError:
  425. # Deserialization from Python. The TaskCancelledError is
  426. # serialized and returned directly.
  427. obj = self._deserialize_msgpack_data(
  428. data, metadata_fields, out_of_band_tensors
  429. )
  430. return RayError.from_bytes(obj)
  431. elif error_type == ErrorType.Value("OBJECT_LOST"):
  432. return ObjectLostError(
  433. object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
  434. )
  435. elif error_type == ErrorType.Value("OBJECT_FETCH_TIMED_OUT"):
  436. return ObjectFetchTimedOutError(
  437. object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
  438. )
  439. elif error_type == ErrorType.Value("OUT_OF_DISK_ERROR"):
  440. return OutOfDiskError(
  441. object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
  442. )
  443. elif error_type == ErrorType.Value("OUT_OF_MEMORY"):
  444. error_info = self._deserialize_error_info(data, metadata_fields)
  445. return OutOfMemoryError(error_info.error_message)
  446. elif error_type == ErrorType.Value("NODE_DIED"):
  447. error_info = self._deserialize_error_info(data, metadata_fields)
  448. return NodeDiedError(error_info.error_message)
  449. elif error_type == ErrorType.Value("OBJECT_DELETED"):
  450. return ReferenceCountingAssertionError(
  451. object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
  452. )
  453. elif error_type == ErrorType.Value("OBJECT_FREED"):
  454. return ObjectFreedError(
  455. object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
  456. )
  457. elif error_type == ErrorType.Value("OWNER_DIED"):
  458. return OwnerDiedError(
  459. object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
  460. )
  461. elif error_type == ErrorType.Value("RUNTIME_ENV_SETUP_FAILED"):
  462. error_info = self._deserialize_error_info(data, metadata_fields)
  463. # TODO(sang): Assert instead once actor also reports error messages.
  464. error_msg = ""
  465. if error_info.HasField("runtime_env_setup_failed_error"):
  466. error_msg = error_info.runtime_env_setup_failed_error.error_message
  467. return RuntimeEnvSetupError(error_message=error_msg)
  468. elif error_type == ErrorType.Value("TASK_PLACEMENT_GROUP_REMOVED"):
  469. return TaskPlacementGroupRemoved()
  470. elif error_type == ErrorType.Value("ACTOR_PLACEMENT_GROUP_REMOVED"):
  471. return ActorPlacementGroupRemoved()
  472. elif error_type == ErrorType.Value("TASK_UNSCHEDULABLE_ERROR"):
  473. error_info = self._deserialize_error_info(data, metadata_fields)
  474. return TaskUnschedulableError(error_info.error_message)
  475. elif error_type == ErrorType.Value("ACTOR_UNSCHEDULABLE_ERROR"):
  476. error_info = self._deserialize_error_info(data, metadata_fields)
  477. return ActorUnschedulableError(error_info.error_message)
  478. elif error_type == ErrorType.Value("END_OF_STREAMING_GENERATOR"):
  479. return ObjectRefStreamEndOfStreamError()
  480. elif error_type == ErrorType.Value("ACTOR_UNAVAILABLE"):
  481. error_info = self._deserialize_error_info(data, metadata_fields)
  482. if error_info.HasField("actor_unavailable_error"):
  483. actor_id = error_info.actor_unavailable_error.actor_id
  484. else:
  485. actor_id = None
  486. return ActorUnavailableError(error_info.error_message, actor_id)
  487. elif ErrorType.Name(error_type).startswith("OBJECT_UNRECONSTRUCTABLE_"):
  488. return ObjectReconstructionFailedError(
  489. object_ref.hex(),
  490. reason=error_type,
  491. owner_address=object_ref.owner_address(),
  492. call_site=object_ref.call_site(),
  493. )
  494. else:
  495. return RaySystemError("Unrecognized error type " + str(error_type))
  496. elif data:
  497. raise ValueError("non-null object should always have metadata")
  498. else:
  499. # Object isn't available in plasma. This should never be returned
  500. # to the user. We should only reach this line if this object was
  501. # deserialized as part of a list, and another object in the list
  502. # throws an exception.
  503. return PlasmaObjectNotAvailable
  504. def deserialize_objects(
  505. self,
  506. serialized_ray_objects: List[SerializedRayObject],
  507. object_refs,
  508. gpu_objects: Dict[str, List["torch.Tensor"]],
  509. ):
  510. assert len(serialized_ray_objects) == len(object_refs)
  511. # initialize the thread-local field
  512. if not hasattr(self._thread_local, "object_ref_stack"):
  513. self._thread_local.object_ref_stack = []
  514. results = []
  515. for object_ref, (data, metadata, transport) in zip(
  516. object_refs, serialized_ray_objects
  517. ):
  518. try:
  519. # Push the object ref to the stack, so the object under
  520. # the object ref knows where it comes from.
  521. self._thread_local.object_ref_stack.append(object_ref)
  522. object_tensors = None
  523. if object_ref is not None:
  524. object_id = object_ref.hex()
  525. if object_id in gpu_objects:
  526. object_tensors = gpu_objects[object_id]
  527. obj = self._deserialize_object(
  528. data,
  529. metadata,
  530. object_ref,
  531. object_tensors,
  532. )
  533. except Exception as e:
  534. logger.exception(e)
  535. obj = RaySystemError(e, traceback.format_exc())
  536. finally:
  537. # Must clear ObjectRef to not hold a reference.
  538. if self._thread_local.object_ref_stack:
  539. self._thread_local.object_ref_stack.pop()
  540. results.append(obj)
  541. return results
  542. def _serialize_to_pickle5(self, metadata, value):
  543. writer = Pickle5Writer()
  544. # TODO(swang): Check that contained_object_refs is empty.
  545. try:
  546. self.set_in_band_serialization()
  547. inband = pickle.dumps(
  548. value, protocol=5, buffer_callback=writer.buffer_callback
  549. )
  550. except Exception as e:
  551. self.get_and_clear_contained_object_refs()
  552. raise e
  553. finally:
  554. self.set_out_of_band_serialization()
  555. return Pickle5SerializedObject(
  556. metadata, inband, writer, self.get_and_clear_contained_object_refs()
  557. )
  558. def _serialize_to_msgpack(self, value):
  559. # Only RayTaskError is possible to be serialized here. We don't
  560. # need to deal with other exception types here.
  561. contained_object_refs = []
  562. if isinstance(value, RayTaskError):
  563. if issubclass(value.cause.__class__, TaskCancelledError):
  564. # Handle task cancellation errors separately because we never
  565. # want to warn about tasks that were intentionally cancelled by
  566. # the user.
  567. metadata = str(ErrorType.Value("TASK_CANCELLED")).encode("ascii")
  568. value = value.to_bytes()
  569. else:
  570. metadata = str(ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode(
  571. "ascii"
  572. )
  573. value = value.to_bytes()
  574. elif isinstance(value, ray.actor.ActorHandle):
  575. # TODO(fyresone): ActorHandle should be serialized via the
  576. # custom type feature of cross-language.
  577. serialized, actor_handle_id, weak_ref = value._serialization_helper()
  578. if not weak_ref:
  579. contained_object_refs.append(actor_handle_id)
  580. # Update ref counting for the actor handle
  581. metadata = ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE
  582. # Append a 1 to mean weak ref or 0 for strong ref.
  583. # We do this here instead of in the main serialization helper
  584. # because msgpack expects a bytes object. We cannot serialize
  585. # `weak_ref` in the C++ code because the weak_ref property is only
  586. # available in the Python ActorHandle instance.
  587. value = serialized + (b"1" if weak_ref else b"0")
  588. else:
  589. metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
  590. python_objects = []
  591. def _python_serializer(o):
  592. index = len(python_objects)
  593. python_objects.append(o)
  594. return index
  595. msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)
  596. if python_objects:
  597. metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
  598. pickle5_serialized_object = self._serialize_to_pickle5(
  599. metadata, python_objects
  600. )
  601. else:
  602. pickle5_serialized_object = None
  603. return MessagePackSerializedObject(
  604. metadata, msgpack_data, contained_object_refs, pickle5_serialized_object
  605. )
  606. def serialize_gpu_objects(
  607. self,
  608. value: Any,
  609. ) -> Tuple[MessagePackSerializedObject, List["torch.Tensor"]]:
  610. """Retrieve GPU data from `value` and store it in the GPU object store. Then, return the serialized value.
  611. Args:
  612. value: The value to serialize.
  613. Returns:
  614. Serialized value.
  615. """
  616. if not self._torch_custom_serializer_registered:
  617. # Register a custom serializer for torch.Tensor. If the method is
  618. # decorated with `@ray.method(tensor_transport="xxx")`, it will
  619. # use external transport (e.g. gloo, nccl, etc.) for tensor
  620. # communication between actors, instead of the normal serialize ->
  621. # object store -> deserialize codepath.
  622. from ray.experimental.channel.torch_tensor_type import TorchTensorType
  623. TorchTensorType().register_custom_serializer()
  624. self._torch_custom_serializer_registered = True
  625. serialized_val, tensors = self._serialize_and_retrieve_tensors(value)
  626. return serialized_val, tensors
  627. def store_gpu_objects(
  628. self, obj_id: str, tensors: List["torch.Tensor"], tensor_transport: str
  629. ) -> bytes:
  630. """
  631. Store GPU objects in the GPU object store.
  632. Args:
  633. obj_id: The object ID of the value. `obj_id` is required, and the GPU data (e.g. tensors) in `value`
  634. will be stored in the GPU object store with the key `obj_id`.
  635. tensors: The tensors to store in the GPU object store.
  636. tensor_transport: The transport with which the RDT object will be transferred.
  637. Returns:
  638. The serialized tensor transport metadata
  639. """
  640. assert (
  641. obj_id is not None
  642. ), "`obj_id` is required, and it is the key to retrieve corresponding tensors from the GPU object store."
  643. # Regardless of whether `tensors` is empty, we always store the GPU object in
  644. # the GPU object store. This ensures that direct transport system tasks are not
  645. # blocked indefinitely.
  646. worker = ray._private.worker.global_worker
  647. gpu_object_manager = worker.gpu_object_manager
  648. tensor_transport_meta = gpu_object_manager.gpu_object_store.add_object_primary(
  649. obj_id, tensors, tensor_transport
  650. )
  651. return pickle.dumps(tensor_transport_meta)
  652. def serialize(
  653. self, value: Any
  654. ) -> Union[RawSerializedObject, MessagePackSerializedObject]:
  655. """Serialize an object.
  656. Args:
  657. value: The value to serialize.
  658. Returns:
  659. Serialized value.
  660. """
  661. if isinstance(value, bytes):
  662. # If the object is a byte array, skip serializing it and
  663. # use a special metadata to indicate it's raw binary. So
  664. # that this object can also be read by Java.
  665. return RawSerializedObject(value)
  666. else:
  667. return self._serialize_to_msgpack(value)
  668. def _serialize_and_retrieve_tensors(
  669. self, value: Any
  670. ) -> Tuple[MessagePackSerializedObject, List["torch.Tensor"]]:
  671. """
  672. Serialize `value` and return the serialized value and any tensors retrieved from `value`.
  673. This is only used for GPU objects.
  674. """
  675. from ray.experimental.channel import ChannelContext
  676. ctx = ChannelContext.get_current().serialization_context
  677. prev_use_external_transport = ctx.use_external_transport
  678. ctx.set_use_external_transport(True)
  679. try:
  680. serialized_val = self._serialize_to_msgpack(value)
  681. finally:
  682. ctx.set_use_external_transport(prev_use_external_transport)
  683. tensors, _ = ctx.reset_out_of_band_tensors([])
  684. return serialized_val, tensors