pipeline.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import logging
  2. from collections import defaultdict
  3. from typing import Any, List
  4. from ray.rllib.connectors.connector import (
  5. AgentConnector,
  6. Connector,
  7. ConnectorContext,
  8. ConnectorPipeline,
  9. )
  10. from ray.rllib.connectors.registry import get_connector, register_connector
  11. from ray.rllib.utils.annotations import OldAPIStack
  12. from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
  13. from ray.util.timer import _Timer
  14. logger = logging.getLogger(__name__)
  15. @OldAPIStack
  16. class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
  17. def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
  18. super().__init__(ctx, connectors)
  19. self.timers = defaultdict(_Timer)
  20. def reset(self, env_id: str):
  21. for c in self.connectors:
  22. c.reset(env_id)
  23. def on_policy_output(self, output: ActionConnectorDataType):
  24. for c in self.connectors:
  25. c.on_policy_output(output)
  26. def __call__(
  27. self, acd_list: List[AgentConnectorDataType]
  28. ) -> List[AgentConnectorDataType]:
  29. ret = acd_list
  30. for c in self.connectors:
  31. timer = self.timers[str(c)]
  32. with timer:
  33. ret = c(ret)
  34. return ret
  35. def to_state(self):
  36. children = []
  37. for c in self.connectors:
  38. state = c.to_state()
  39. assert isinstance(state, tuple) and len(state) == 2, (
  40. "Serialized connector state must be in the format of "
  41. f"Tuple[name: str, params: Any]. Instead we got {state}"
  42. f"for connector {c.__name__}."
  43. )
  44. children.append(state)
  45. return AgentConnectorPipeline.__name__, children
  46. @staticmethod
  47. def from_state(ctx: ConnectorContext, params: List[Any]):
  48. assert (
  49. type(params) is list
  50. ), "AgentConnectorPipeline takes a list of connector params."
  51. connectors = []
  52. for state in params:
  53. try:
  54. name, subparams = state
  55. connectors.append(get_connector(name, ctx, subparams))
  56. except Exception as e:
  57. logger.error(f"Failed to de-serialize connector state: {state}")
  58. raise e
  59. return AgentConnectorPipeline(ctx, connectors)
  60. register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)