immutable.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from typing import Any
  2. import tree # pip install dm_tree
  3. from ray.rllib.connectors.connector import (
  4. ActionConnector,
  5. ConnectorContext,
  6. )
  7. from ray.rllib.connectors.registry import register_connector
  8. from ray.rllib.utils.annotations import OldAPIStack
  9. from ray.rllib.utils.numpy import make_action_immutable
  10. from ray.rllib.utils.typing import ActionConnectorDataType
  11. @OldAPIStack
  12. class ImmutableActionsConnector(ActionConnector):
  13. def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
  14. assert isinstance(
  15. ac_data.output, tuple
  16. ), "Action connector requires PolicyOutputType data."
  17. actions, states, fetches = ac_data.output
  18. tree.traverse(make_action_immutable, actions, top_down=False)
  19. return ActionConnectorDataType(
  20. ac_data.env_id,
  21. ac_data.agent_id,
  22. ac_data.input_dict,
  23. (actions, states, fetches),
  24. )
  25. def to_state(self):
  26. return ImmutableActionsConnector.__name__, None
  27. @staticmethod
  28. def from_state(ctx: ConnectorContext, params: Any):
  29. return ImmutableActionsConnector(ctx)
  30. register_connector(ImmutableActionsConnector.__name__, ImmutableActionsConnector)