rllink.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from enum import Enum
  2. from packaging.version import Version
  3. from ray.rllib.utils.checkpoints import try_import_msgpack
  4. from ray.util.annotations import DeveloperAPI
# Placeholder for the `msgpack` module. `send_rllink_message()` and
# `get_rllink_message()` replace it (via `global msgpack`) with the result of
# `try_import_msgpack(error=True)` on first use, so the import is deferred until
# a message is actually sent or received.
msgpack = None
@DeveloperAPI
class RLlink(Enum):
    """Message types of the RLlink protocol.

    Requests flow from the client (external env) to the server (RLlib);
    responses flow from the server back to the client. For the string-valued
    members, the value is the string expected in a received message's `type`
    field: `get_rllink_message()` maps that field back to a member via
    `RLlink(value)`.
    """

    # NOTE(review): Because it is assigned inside an Enum body, PROTOCOL_VERSION
    # is itself an enum member (`RLlink.PROTOCOL_VERSION.value` is the Version
    # object), not a plain class constant.
    PROTOCOL_VERSION = Version("0.0.1")

    # Requests: Client (external env) -> Server (RLlib).
    # ----
    # Ping command (initial handshake).
    PING = "PING"
    # List of episodes (similar to what an EnvRunner.sample() call would return).
    EPISODES = "EPISODES"
    # Request state (e.g. model weights).
    GET_STATE = "GET_STATE"
    # Request Algorithm config.
    GET_CONFIG = "GET_CONFIG"
    # Send episodes and request the next state update right after that.
    # Clients sending this message should wait for a SET_STATE message as an immediate
    # response. Useful for external samplers that must collect on-policy data.
    EPISODES_AND_GET_STATE = "EPISODES_AND_GET_STATE"

    # Responses: Server (RLlib) -> Client (external env).
    # ----
    # Pong response (initial handshake).
    PONG = "PONG"
    # Set state (e.g. model weights).
    SET_STATE = "SET_STATE"
    # Set Algorithm config.
    SET_CONFIG = "SET_CONFIG"

    # @OldAPIStack (to be deprecated soon).
    ACTION_SPACE = "ACTION_SPACE"
    OBSERVATION_SPACE = "OBSERVATION_SPACE"
    GET_WORKER_ARGS = "GET_WORKER_ARGS"
    GET_WEIGHTS = "GET_WEIGHTS"
    REPORT_SAMPLES = "REPORT_SAMPLES"
    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"

    def __str__(self):
        # Render members by name (e.g. "PING") instead of Enum's default
        # "RLlink.PING".
        return self.name
  44. @DeveloperAPI
  45. def send_rllink_message(sock_, message: dict):
  46. """Sends a message to the client with a length header."""
  47. global msgpack
  48. if msgpack is None:
  49. msgpack = try_import_msgpack(error=True)
  50. body = msgpack.packb(message, use_bin_type=True) # .encode("utf-8")
  51. header = str(len(body)).zfill(8).encode("utf-8")
  52. try:
  53. sock_.sendall(header + body)
  54. except Exception as e:
  55. raise ConnectionError(
  56. f"Error sending message {message} to server on socket {sock_}! "
  57. f"Original error was: {e}"
  58. )
  59. @DeveloperAPI
  60. def get_rllink_message(sock_):
  61. """Receives a message from the client following the length-header protocol."""
  62. global msgpack
  63. if msgpack is None:
  64. msgpack = try_import_msgpack(error=True)
  65. try:
  66. # Read the length header (8 bytes)
  67. header = _get_num_bytes(sock_, 8)
  68. msg_length = int(header.decode("utf-8"))
  69. # Read the message body
  70. body = _get_num_bytes(sock_, msg_length)
  71. # Decode JSON.
  72. message = msgpack.unpackb(body, raw=False) # .loads(body.decode("utf-8"))
  73. # Check for proper protocol.
  74. if "type" not in message:
  75. raise ConnectionError(
  76. "Protocol Error! Message from peer does not contain `type` field."
  77. )
  78. return RLlink(message.pop("type")), message
  79. except Exception as e:
  80. raise ConnectionError(
  81. f"Error receiving message from peer on socket {sock_}! "
  82. f"Original error was: {e}"
  83. )
  84. def _get_num_bytes(sock_, num_bytes):
  85. """Helper function to receive a specific number of bytes."""
  86. data = b""
  87. while len(data) < num_bytes:
  88. packet = sock_.recv(num_bytes - len(data))
  89. if not packet:
  90. raise ConnectionError(f"No data received from socket {sock_}!")
  91. data += packet
  92. return data