lambdas.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Any, Callable, Type
  2. import numpy as np
  3. import tree # dm_tree
  4. from ray.rllib.connectors.connector import (
  5. AgentConnector,
  6. ConnectorContext,
  7. )
  8. from ray.rllib.connectors.registry import register_connector
  9. from ray.rllib.policy.sample_batch import SampleBatch
  10. from ray.rllib.utils.annotations import OldAPIStack
  11. from ray.rllib.utils.typing import (
  12. AgentConnectorDataType,
  13. AgentConnectorsOutput,
  14. )
  15. @OldAPIStack
  16. def register_lambda_agent_connector(
  17. name: str, fn: Callable[[Any], Any]
  18. ) -> Type[AgentConnector]:
  19. """A util to register any simple transforming function as an AgentConnector
  20. The only requirement is that fn should take a single data object and return
  21. a single data object.
  22. Args:
  23. name: Name of the resulting actor connector.
  24. fn: The function that transforms env / agent data.
  25. Returns:
  26. A new AgentConnector class that transforms data using fn.
  27. """
  28. class LambdaAgentConnector(AgentConnector):
  29. def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
  30. return AgentConnectorDataType(
  31. ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
  32. )
  33. def to_state(self):
  34. return name, None
  35. @staticmethod
  36. def from_state(ctx: ConnectorContext, params: Any):
  37. return LambdaAgentConnector(ctx)
  38. LambdaAgentConnector.__name__ = name
  39. LambdaAgentConnector.__qualname__ = name
  40. register_connector(name, LambdaAgentConnector)
  41. return LambdaAgentConnector
  42. @OldAPIStack
  43. def flatten_data(data: AgentConnectorsOutput):
  44. assert isinstance(
  45. data, AgentConnectorsOutput
  46. ), "Single agent data must be of type AgentConnectorsOutput"
  47. raw_dict = data.raw_dict
  48. sample_batch = data.sample_batch
  49. flattened = {}
  50. for k, v in sample_batch.items():
  51. if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"):
  52. # Do not flatten infos, actions, and state_out_ columns.
  53. flattened[k] = v
  54. continue
  55. if v is None:
  56. # Keep the same column shape.
  57. flattened[k] = None
  58. continue
  59. flattened[k] = np.array(tree.flatten(v))
  60. flattened = SampleBatch(flattened, is_training=False)
  61. return AgentConnectorsOutput(raw_dict, flattened)
# Agent connector that builds and returns a flattened observation SampleBatch
# in addition to the original input dict.
# The class is generated by register_lambda_agent_connector, which registers
# it under the name "FlattenDataAgentConnector"; its transform applies
# flatten_data to the data payload of each item it receives.
FlattenDataAgentConnector = OldAPIStack(
    register_lambda_agent_connector("FlattenDataAgentConnector", flatten_data)
)