lambdas.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from typing import Any, Callable, Dict, Type
  2. from ray.rllib.connectors.connector import (
  3. ActionConnector,
  4. ConnectorContext,
  5. )
  6. from ray.rllib.connectors.registry import register_connector
  7. from ray.rllib.utils.annotations import OldAPIStack
  8. from ray.rllib.utils.numpy import convert_to_numpy
  9. from ray.rllib.utils.typing import (
  10. ActionConnectorDataType,
  11. PolicyOutputType,
  12. StateBatches,
  13. TensorStructType,
  14. )
  @OldAPIStack
  def register_lambda_action_connector(
      name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
  ) -> Type[ActionConnector]:
      """Build and register an ActionConnector class that applies ``fn``.

      ``fn`` is called with the three elements of the policy output (actions,
      states, and fetches); whatever it returns becomes the output of the
      connector data produced by the generated class.

      Args:
          name: Name under which the generated class is registered. It is also
              used as the class's ``__name__`` / ``__qualname__`` and as the
              first element of its serialized state.
          fn: The function that transforms PolicyOutputType.

      Returns:
          A new ActionConnector subclass that transforms PolicyOutputType
          using ``fn``.
      """

      class LambdaActionConnector(ActionConnector):
          # Connector whose transform step delegates to the captured ``fn``.

          def transform(
              self, ac_data: ActionConnectorDataType
          ) -> ActionConnectorDataType:
              policy_output = ac_data.output
              assert isinstance(
                  policy_output, tuple
              ), "Action connector requires PolicyOutputType data."

              # Unpacking requires exactly (actions, states, fetches).
              acts, state_outs, extra_fetches = policy_output
              transformed = fn(acts, state_outs, extra_fetches)

              # Everything except the output is carried over unchanged.
              return ActionConnectorDataType(
                  ac_data.env_id,
                  ac_data.agent_id,
                  ac_data.input_dict,
                  transformed,
              )

          def to_state(self):
              # Only the registered name is serialized; there are no params.
              return (name, None)

          @staticmethod
          def from_state(ctx: ConnectorContext, params: Any):
              # ``params`` is unused: the connector is rebuilt from ``ctx`` alone.
              return LambdaActionConnector(ctx)

      # Expose the generated class under the caller-supplied name instead of
      # the generic "LambdaActionConnector".
      for dunder in ("__name__", "__qualname__"):
          setattr(LambdaActionConnector, dunder, name)

      register_connector(name, LambdaActionConnector)
      return LambdaActionConnector
  # Action connector that passes the policy's actions and states through
  # convert_to_numpy(); the fetches are handed on untouched.
  ConvertToNumpyConnector = OldAPIStack(
      register_lambda_action_connector(
          "ConvertToNumpyConnector",
          lambda acts, state_outs, extra_fetches: (
              convert_to_numpy(acts),
              convert_to_numpy(state_outs),
              extra_fetches,
          ),
      )
  )
  61. )