server_pickler.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  """Implements the server side of the client/server pickling protocol.

  These picklers are aware of the server internals and can find the
  references held for the client within the server.
  More discussion about the client/server pickling protocol can be found in:
  ray/util/client/client_pickler.py
  ServerPickler dumps Ray objects from the server into the appropriate stubs.
  ClientUnpickler loads stubs from the client and finds their associated handle
  in the server instance.
  """
  10. import io
  11. from typing import TYPE_CHECKING, Any
  12. import ray
  13. import ray.cloudpickle as cloudpickle
  14. from ray._private.client_mode_hook import disable_client_hook
  15. from ray.util.client.client_pickler import PickleStub
  16. from ray.util.client.server.server_stubs import (
  17. ClientReferenceActor,
  18. ClientReferenceFunction,
  19. )
  20. if TYPE_CHECKING:
  21. from ray.util.client.server.server import RayletServicer
  22. import pickle # noqa: F401
  class ServerPickler(cloudpickle.CloudPickler):
      """Pickler used by the server to send values back to a client.

      Ray object references and actor handles met while pickling are not
      serialized by value. They are registered in the server's bookkeeping
      tables on behalf of ``client_id`` and written out as ``PickleStub``
      persistent ids instead.
      """

      def __init__(self, client_id: str, server: "RayletServicer", *args, **kwargs):
          # Remaining arguments (e.g. the output file and ``protocol``) are
          # handed to CloudPickler unchanged.
          super().__init__(*args, **kwargs)
          self.client_id = client_id
          self.server = server

      def persistent_id(self, obj):
          """Return a ``PickleStub`` for Ray handles, or None for anything else.

          Returning None makes the pickler serialize ``obj`` normally.

          Side effects:
              * ``ray.ObjectRef``: stored in ``server.object_refs[client_id]``
                under its binary id unless already present, so the server
                holds onto it for the client.
              * ``ray.actor.ActorHandle``: stored in ``server.actor_refs`` under
                its binary actor id unless already present, and the id is added
                to ``server.actor_owners[client_id]``.
          """
          if isinstance(obj, ray.ObjectRef):
              ref_id = obj.binary()
              client_refs = self.server.object_refs[self.client_id]
              if ref_id not in client_refs:
                  # Passing back a reference, probably nested inside another
                  # reference: hold onto it for the client.
                  client_refs[ref_id] = obj
              return PickleStub(
                  type="Object",
                  client_id=self.client_id,
                  ref_id=ref_id,
                  name=None,
                  baseline_options=None,
              )

          if isinstance(obj, ray.actor.ActorHandle):
              actor_id = obj._actor_id.binary()
              actor_table = self.server.actor_refs
              if actor_id not in actor_table:
                  # Passing back a handle, probably nested inside a reference.
                  actor_table[actor_id] = obj
              owned_actors = self.server.actor_owners[self.client_id]
              if actor_id not in owned_actors:
                  owned_actors.add(actor_id)
              return PickleStub(
                  type="Actor",
                  client_id=self.client_id,
                  ref_id=actor_id,
                  name=None,
                  baseline_options=None,
              )

          # Not a Ray handle: let regular (cloud)pickling handle it.
          return None
  class ClientUnpickler(pickle.Unpickler):
      """Unpickler that maps client ``PickleStub`` persistent ids to server objects.

      SECURITY NOTE(review): this is a plain ``pickle.Unpickler`` applied to
      bytes received from a client, and ``find_class`` is not restricted, so a
      malicious payload can execute arbitrary code on the server. Only accept
      data from trusted clients.
      """

      def __init__(self, server, *args, **kwargs):
          """Create an unpickler bound to ``server``.

          Args:
              server: Server instance (presumably the ``RayletServicer``) whose
                  ``object_refs``/``actor_refs`` tables and
                  ``lookup_or_register_func``/``lookup_or_register_actor``
                  methods are used to resolve stubs.
              *args, **kwargs: Forwarded unchanged to ``pickle.Unpickler``
                  (the input file, ``fix_imports``, ``encoding``, ...).
          """
          super().__init__(*args, **kwargs)
          self.server = server

      def persistent_load(self, pid):
          """Resolve a persistent id written by the client-side pickler.

          Args:
              pid: Expected to be a ``PickleStub``; ``pid.type`` selects how
                  it is resolved.

          Returns:
              The ``ray`` module, an object ref or actor handle held by the
              server, a ``ClientReferenceFunction``/``ClientReferenceActor``
              for self-references, the result of the server's
              ``lookup_or_register_func``/``lookup_or_register_actor``, or a
              method of a held actor.

          Raises:
              pickle.UnpicklingError: If ``pid`` is not a ``PickleStub``.
              NotImplementedError: If ``pid.type`` is not a known stub type.

          Lookups into the server tables use subscripting, so an unknown id
          propagates whatever that table raises (``KeyError`` for a plain dict).
          """
          if not isinstance(pid, PickleStub):
              # Explicit check rather than ``assert``: asserts are stripped under
              # ``python -O``, and this validates data coming from a client.
              raise pickle.UnpicklingError(
                  f"Unsupported persistent id from client: {type(pid).__name__}"
              )
          stub_type = pid.type
          if stub_type == "Ray":
              # The ``ray`` module itself.
              return ray
          if stub_type == "Object":
              return self.server.object_refs[pid.client_id][pid.ref_id]
          if stub_type == "Actor":
              return self.server.actor_refs[pid.ref_id]
          if stub_type == "RemoteFuncSelfReference":
              return ClientReferenceFunction(pid.client_id, pid.ref_id)
          if stub_type == "RemoteFunc":
              return self.server.lookup_or_register_func(
                  pid.ref_id, pid.client_id, pid.baseline_options
              )
          if stub_type == "RemoteActorSelfReference":
              return ClientReferenceActor(pid.client_id, pid.ref_id)
          if stub_type == "RemoteActor":
              return self.server.lookup_or_register_actor(
                  pid.ref_id, pid.client_id, pid.baseline_options
              )
          if stub_type == "RemoteMethod":
              # Attribute ``pid.name`` of a server-held actor handle.
              actor = self.server.actor_refs[pid.ref_id]
              return getattr(actor, pid.name)
          # Include the offending type so protocol mismatches are diagnosable.
          raise NotImplementedError(f"Uncovered client data type: {stub_type!r}")
  def dumps_from_server(
      obj: Any, client_id: str, server_instance: "RayletServicer", protocol=None
  ) -> bytes:
      """Pickle ``obj`` on the server for delivery to client ``client_id``.

      Ray object refs and actor handles inside ``obj`` are registered with
      ``server_instance`` and emitted as stubs by ``ServerPickler``.

      Args:
          obj: Value to serialize.
          client_id: Id of the client the payload is destined for.
          server_instance: Server whose reference tables record held handles.
          protocol: Pickle protocol forwarded to the pickler (None = default).

      Returns:
          The pickled bytes.
      """
      buffer = io.BytesIO()
      with buffer:
          ServerPickler(client_id, server_instance, buffer, protocol=protocol).dump(obj)
          # Read the bytes out before the ``with`` closes the buffer.
          payload = buffer.getvalue()
      return payload
  def loads_from_client(
      data: bytes,
      server_instance: "RayletServicer",
      *,
      fix_imports=True,
      encoding="ASCII",
      errors="strict"
  ) -> Any:
      """Unpickle ``data`` received from a client, resolving its stubs.

      Args:
          data: Pickled payload produced by the client-side pickler.
          server_instance: Server whose reference tables back the stubs.
          fix_imports: Forwarded to ``pickle.Unpickler``.
          encoding: Forwarded to ``pickle.Unpickler``.
          errors: Forwarded to ``pickle.Unpickler`` (previously accepted but
              silently ignored).

      Returns:
          The unpickled object.

      Raises:
          TypeError: If ``data`` is a ``str`` rather than bytes.
      """
      # Same guard (and message) as the pure-Python ``pickle._loads``.
      if isinstance(data, str):
          raise TypeError("Can't load pickle from unicode string")
      # Client-mode hooks are disabled while loading, as before.
      with disable_client_hook(), io.BytesIO(data) as file:
          return ClientUnpickler(
              server_instance,
              file,
              fix_imports=fix_imports,
              encoding=encoding,
              errors=errors,
          ).load()