from typing import Any, Callable, Type

import numpy as np
import tree  # dm_tree

from ray.rllib.connectors.connector import (
    AgentConnector,
    ConnectorContext,
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.typing import (
    AgentConnectorDataType,
    AgentConnectorsOutput,
)


@OldAPIStack
def register_lambda_agent_connector(
    name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
    """Wrap a plain transforming function into a registered AgentConnector.

    `fn` must accept one data object and return one data object.

    Args:
        name: Name under which the resulting agent connector class is
            registered; also used as the class' `__name__`/`__qualname__`.
        fn: The function applied to the `data` field of each incoming
            AgentConnectorDataType.

    Returns:
        The newly created (and registered) AgentConnector subclass.
    """

    class LambdaAgentConnector(AgentConnector):
        def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
            # Env and agent ids are carried over unchanged; only the payload
            # is passed through `fn`.
            env_id, agent_id = ac_data.env_id, ac_data.agent_id
            return AgentConnectorDataType(env_id, agent_id, fn(ac_data.data))

        def to_state(self):
            # The connector has no parameters of its own, so the state is
            # just its registered name plus an empty params slot.
            return (name, None)

        @staticmethod
        def from_state(ctx: ConnectorContext, params: Any):
            # `params` is ignored: `to_state` never stores anything in it.
            return LambdaAgentConnector(ctx)

    # Make the generated class identify itself by the registered name.
    LambdaAgentConnector.__name__ = name
    LambdaAgentConnector.__qualname__ = name

    register_connector(name, LambdaAgentConnector)

    return LambdaAgentConnector


@OldAPIStack
def flatten_data(data: AgentConnectorsOutput):
    """Flatten the columns of the sample batch carried by `data`.

    The INFOS and ACTIONS columns, every column whose key starts with
    "state_out_", and every column whose value is None are passed through
    untouched. Each remaining column is replaced by
    `np.array(tree.flatten(value))`.

    Args:
        data: The AgentConnectorsOutput whose `sample_batch` is flattened.

    Returns:
        A new AgentConnectorsOutput holding the original `raw_dict` and a
        new SampleBatch (built with `is_training=False`) of the processed
        columns.
    """
    assert isinstance(
        data, AgentConnectorsOutput
    ), "Single agent data must be of type AgentConnectorsOutput"

    original_raw = data.raw_dict
    batch = data.sample_batch

    columns = {}
    for key, value in batch.items():
        # Infos, actions and state_out_ columns are never flattened; None
        # columns stay None so the column is still present in the output.
        keep_as_is = (
            key in (SampleBatch.INFOS, SampleBatch.ACTIONS)
            or key.startswith("state_out_")
            or value is None
        )
        columns[key] = value if keep_as_is else np.array(tree.flatten(value))

    flat_batch = SampleBatch(columns, is_training=False)
    return AgentConnectorsOutput(original_raw, flat_batch)


# Agent connector that returns a SampleBatch with flattened columns
# alongside the original input dict.
FlattenDataAgentConnector = OldAPIStack(
    register_lambda_agent_connector("FlattenDataAgentConnector", flatten_data)
)