client_pickler.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
"""Implements the client side of the client/server pickling protocol.

All Ray Client data transfer between client and server happens through this
pickling protocol. The model is as follows:

* All Client objects (e.g. ClientObjectRef) always live on the client and
  are never represented on the server.
* All Ray objects (e.g. ray.ObjectRef) always live on the server and are
  never returned to the client.
* To translate between these two kinds of reference, PickleStub tuples are
  generated as persistent IDs in the data blobs during the pickling and
  unpickling of these objects.

The PickleStubs carry just enough information to find or generate their
associated partner object on either side.

This also has the advantage of avoiding the predefined pickle behavior for
Ray objects, which may include Ray-internal reference counting.

ClientPickler dumps objects from the client into the appropriate stubs.
ServerUnpickler loads stubs from the server into their client counterparts.
"""
  18. import io
  19. import pickle # noqa: F401
  20. from typing import Any, Dict, NamedTuple, Optional
  21. import ray.cloudpickle as cloudpickle
  22. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  23. from ray.util.client import RayAPIStub
  24. from ray.util.client.common import (
  25. ClientActorClass,
  26. ClientActorHandle,
  27. ClientActorRef,
  28. ClientObjectRef,
  29. ClientRemoteFunc,
  30. ClientRemoteMethod,
  31. InProgressSentinel,
  32. OptionWrapper,
  33. )
  34. # NOTE(barakmich): These PickleStubs are really close to
  35. # the data for an execution, with no arguments. Combine the two?
  36. class PickleStub(
  37. NamedTuple(
  38. "PickleStub",
  39. [
  40. ("type", str),
  41. ("client_id", str),
  42. ("ref_id", bytes),
  43. ("name", Optional[str]),
  44. ("baseline_options", Optional[Dict]),
  45. ],
  46. )
  47. ):
  48. def __reduce__(self):
  49. # PySpark's namedtuple monkey patch breaks compatibility with
  50. # cloudpickle. Thus we revert this patch here if it exists.
  51. return object.__reduce__(self)
  52. class ClientPickler(cloudpickle.CloudPickler):
  53. def __init__(self, client_id, *args, **kwargs):
  54. super().__init__(*args, **kwargs)
  55. self.client_id = client_id
  56. def persistent_id(self, obj):
  57. if isinstance(obj, RayAPIStub):
  58. return PickleStub(
  59. type="Ray",
  60. client_id=self.client_id,
  61. ref_id=b"",
  62. name=None,
  63. baseline_options=None,
  64. )
  65. elif isinstance(obj, ClientObjectRef):
  66. return PickleStub(
  67. type="Object",
  68. client_id=self.client_id,
  69. ref_id=obj.id,
  70. name=None,
  71. baseline_options=None,
  72. )
  73. elif isinstance(obj, ClientActorHandle):
  74. return PickleStub(
  75. type="Actor",
  76. client_id=self.client_id,
  77. ref_id=obj._actor_id.id,
  78. name=None,
  79. baseline_options=None,
  80. )
  81. elif isinstance(obj, ClientRemoteFunc):
  82. if obj._ref is None:
  83. obj._ensure_ref()
  84. if type(obj._ref) is InProgressSentinel:
  85. return PickleStub(
  86. type="RemoteFuncSelfReference",
  87. client_id=self.client_id,
  88. ref_id=obj._client_side_ref.id,
  89. name=None,
  90. baseline_options=None,
  91. )
  92. return PickleStub(
  93. type="RemoteFunc",
  94. client_id=self.client_id,
  95. ref_id=obj._ref.id,
  96. name=None,
  97. baseline_options=obj._options,
  98. )
  99. elif isinstance(obj, ClientActorClass):
  100. if obj._ref is None:
  101. obj._ensure_ref()
  102. if type(obj._ref) is InProgressSentinel:
  103. return PickleStub(
  104. type="RemoteActorSelfReference",
  105. client_id=self.client_id,
  106. ref_id=obj._client_side_ref.id,
  107. name=None,
  108. baseline_options=None,
  109. )
  110. return PickleStub(
  111. type="RemoteActor",
  112. client_id=self.client_id,
  113. ref_id=obj._ref.id,
  114. name=None,
  115. baseline_options=obj._options,
  116. )
  117. elif isinstance(obj, ClientRemoteMethod):
  118. return PickleStub(
  119. type="RemoteMethod",
  120. client_id=self.client_id,
  121. ref_id=obj._actor_handle.actor_ref.id,
  122. name=obj._method_name,
  123. baseline_options=None,
  124. )
  125. elif isinstance(obj, OptionWrapper):
  126. raise NotImplementedError("Sending a partial option is unimplemented")
  127. return None
  128. class ServerUnpickler(pickle.Unpickler):
  129. def persistent_load(self, pid):
  130. assert isinstance(pid, PickleStub)
  131. if pid.type == "Object":
  132. return ClientObjectRef(pid.ref_id)
  133. elif pid.type == "Actor":
  134. return ClientActorHandle(ClientActorRef(pid.ref_id))
  135. else:
  136. raise NotImplementedError("Being passed back an unknown stub")
  137. def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
  138. with io.BytesIO() as file:
  139. cp = ClientPickler(client_id, file, protocol=protocol)
  140. cp.dump(obj)
  141. return file.getvalue()
  142. def loads_from_server(
  143. data: bytes, *, fix_imports=True, encoding="ASCII", errors="strict"
  144. ) -> Any:
  145. if isinstance(data, str):
  146. raise TypeError("Can't load pickle from unicode string")
  147. file = io.BytesIO(data)
  148. return ServerUnpickler(
  149. file, fix_imports=fix_imports, encoding=encoding, errors=errors
  150. ).load()
  151. def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
  152. out = ray_client_pb2.Arg()
  153. out.local = ray_client_pb2.Arg.Locality.INTERNED
  154. out.data = dumps_from_client(val, client_id)
  155. return out