| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- from enum import Enum
- from packaging.version import Version
- from ray.rllib.utils.checkpoints import try_import_msgpack
- from ray.util.annotations import DeveloperAPI
# Module-level handle to the `msgpack` module. It is left as None at import time
# and filled in lazily (via `try_import_msgpack`) by the first send/receive call.
msgpack = None
@DeveloperAPI
class RLlink(Enum):
    """Message types (plus the protocol version) of the RLlink socket protocol.

    RLlink connects an external environment (the client) with RLlib (the
    server). When a message is received, its `type` field is looked up by value
    among the members below.
    """

    # NOTE: Because it is defined inside an Enum, this is itself an enum member;
    # the actual `Version` object is `RLlink.PROTOCOL_VERSION.value`.
    PROTOCOL_VERSION = Version("0.0.1")

    # ---- Requests: client (external env) -> server (RLlib). ----

    # Handshake request that opens a session (paired with the PONG response).
    PING = "PING"
    # A list of collected episodes, analogous to what `EnvRunner.sample()`
    # would return.
    EPISODES = "EPISODES"
    # Asks for the current state (e.g. model weights).
    GET_STATE = "GET_STATE"
    # Asks for the Algorithm config.
    GET_CONFIG = "GET_CONFIG"
    # Delivers episodes and asks for the next state update in the same message.
    # A client sending this should wait for a SET_STATE message as the immediate
    # reply, which lets external samplers collect on-policy data.
    EPISODES_AND_GET_STATE = "EPISODES_AND_GET_STATE"

    # ---- Responses: server (RLlib) -> client (external env). ----

    # Handshake reply to PING.
    PONG = "PONG"
    # Delivers state (e.g. model weights).
    SET_STATE = "SET_STATE"
    # Delivers the Algorithm config.
    SET_CONFIG = "SET_CONFIG"

    # ---- Old API stack only (@OldAPIStack; to be deprecated soon). ----
    ACTION_SPACE = "ACTION_SPACE"
    OBSERVATION_SPACE = "OBSERVATION_SPACE"
    GET_WORKER_ARGS = "GET_WORKER_ARGS"
    GET_WEIGHTS = "GET_WEIGHTS"
    REPORT_SAMPLES = "REPORT_SAMPLES"
    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"

    def __str__(self):
        """Renders the bare member name (e.g. "PING"), not "RLlink.PING"."""
        return self.name
@DeveloperAPI
def send_rllink_message(sock_, message: dict):
    """Sends `message` to the peer, prefixed with an 8-byte length header.

    The message is serialized with msgpack. The header is the byte length of the
    serialized body written as zero-padded ASCII decimal digits (exactly 8
    characters), which is the framing `get_rllink_message` reads back.

    Args:
        sock_: The connected socket to send on.
        message: The message dict to serialize and send. Must be
            msgpack-serializable.

    Raises:
        ValueError: If the serialized body is too large for its length to fit
            into the 8-digit header.
        ConnectionError: If writing to the socket fails.
    """
    global msgpack
    if msgpack is None:
        msgpack = try_import_msgpack(error=True)

    body = msgpack.packb(message, use_bin_type=True)

    # The receiver always reads exactly 8 header bytes. A longer length string
    # (body of 100,000,000 bytes or more) would be misparsed on the other end
    # and desynchronize the stream, so refuse to send instead.
    max_body_len = 10**8 - 1
    if len(body) > max_body_len:
        raise ValueError(
            f"Serialized message is {len(body)} bytes, which exceeds the maximum "
            f"of {max_body_len} bytes expressible in the 8-digit RLlink length "
            "header!"
        )
    header = str(len(body)).zfill(8).encode("utf-8")

    try:
        sock_.sendall(header + body)
    except Exception as e:
        # Report only the message type: formatting the whole message could dump
        # a huge payload (e.g. model weights) into the error string.
        msg_type = message.get("type") if isinstance(message, dict) else None
        raise ConnectionError(
            f"Error sending message of type {msg_type} to peer on socket "
            f"{sock_}! Original error was: {e}"
        ) from e
@DeveloperAPI
def get_rllink_message(sock_):
    """Receives one length-prefixed message from the peer.

    Reads an 8-byte ASCII-decimal length header, then exactly that many body
    bytes, and decodes the body with msgpack. The decoded message must be a dict
    containing a `type` field whose value names an `RLlink` member.

    Args:
        sock_: The connected socket to read from.

    Returns:
        A tuple `(message_type, message)`. `message_type` is the `RLlink` member
        looked up from the message's `type` field. `message` is the decoded dict
        with the `type` key removed.

    Raises:
        ConnectionError: On any failure while receiving or decoding. This covers
            the peer closing the connection, a malformed header, an undecodable
            body, a message that is not a dict or lacks a `type` field, and an
            unknown `type` value. The underlying exception is chained as
            `__cause__`.
    """
    global msgpack
    if msgpack is None:
        msgpack = try_import_msgpack(error=True)

    try:
        # Read the fixed-size length header (8 ASCII decimal digits).
        header = _get_num_bytes(sock_, 8)
        # `int()` alone would also accept a sign, whitespace or underscores
        # (e.g. b"-0000001"), so insist on plain digits before parsing.
        if not header.isdigit():
            raise ConnectionError(
                f"Protocol Error! Invalid length header {header!r} from peer."
            )
        msg_length = int(header.decode("utf-8"))

        # Read exactly `msg_length` body bytes.
        body = _get_num_bytes(sock_, msg_length)

        # Decode the msgpack body (`raw=False` unpacks msgpack strings to `str`).
        message = msgpack.unpackb(body, raw=False)

        # Check for proper protocol: a dict carrying a `type` field.
        if not isinstance(message, dict) or "type" not in message:
            raise ConnectionError(
                "Protocol Error! Message from peer does not contain `type` field."
            )
        return RLlink(message.pop("type")), message
    except Exception as e:
        raise ConnectionError(
            f"Error receiving message from peer on socket {sock_}! "
            f"Original error was: {e}"
        ) from e
- def _get_num_bytes(sock_, num_bytes):
- """Helper function to receive a specific number of bytes."""
- data = b""
- while len(data) < num_bytes:
- packet = sock_.recv(num_bytes - len(data))
- if not packet:
- raise ConnectionError(f"No data received from socket {sock_}!")
- data += packet
- return data
|