pipeline.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import logging
  2. from collections import defaultdict
  3. from typing import Any, List
  4. from ray.rllib.connectors.connector import (
  5. ActionConnector,
  6. Connector,
  7. ConnectorContext,
  8. ConnectorPipeline,
  9. )
  10. from ray.rllib.connectors.registry import get_connector, register_connector
  11. from ray.rllib.utils.annotations import OldAPIStack
  12. from ray.rllib.utils.typing import ActionConnectorDataType
  13. from ray.util.timer import _Timer
  14. logger = logging.getLogger(__name__)  # Module-level logger; from_state uses it to report connector states that fail to de-serialize.
  @OldAPIStack
  class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
      """Action connector that runs a list of connectors one after another.

      Calling the pipeline feeds the action data through every connector in
      ``self.connectors`` in order; each connector receives the output of the
      previous one. Every connector call is wrapped in its own ``_Timer``.
      """

      def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
          """Build the pipeline.

          Args:
              ctx: Connector context, forwarded to the parent initializer.
              connectors: Connectors to chain, forwarded to the parent
                  initializer.
          """
          super().__init__(ctx, connectors)
          # One timer per connector, keyed by ``str(connector)`` and created
          # lazily by the defaultdict the first time a connector is called.
          self.timers = defaultdict(_Timer)

      def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
          """Pass ``ac_data`` through all connectors in order.

          Args:
              ac_data: Action data handed to the first connector.

          Returns:
              The value returned by the last connector, or ``ac_data`` unchanged
              if the pipeline has no connectors.
          """
          for c in self.connectors:
              # The timer is used as a context manager around the connector call.
              with self.timers[str(c)]:
                  ac_data = c(ac_data)
          return ac_data

      def to_state(self):
          """Serialize the pipeline.

          Returns:
              A ``(name, children)`` tuple where ``name`` is
              ``ActionConnectorPipeline.__name__`` and ``children`` is the list
              of the ``to_state()`` results of all connectors, in order.

          Raises:
              AssertionError: If a connector's ``to_state()`` does not return a
                  2-tuple.
          """
          children = []
          for c in self.connectors:
              state = c.to_state()
              # ``c`` is an instance, so the class name has to be read from its
              # type; ``c.__name__`` would raise AttributeError while building
              # the message and hide the actual assertion failure.
              assert isinstance(state, tuple) and len(state) == 2, (
                  "Serialized connector state must be in the format of "
                  f"Tuple[name: str, params: Any]. Instead we got {state} "
                  f"for connector {type(c).__name__}."
              )
              children.append(state)
          return ActionConnectorPipeline.__name__, children

      @staticmethod
      def from_state(ctx: ConnectorContext, params: Any):
          """Rebuild a pipeline from the ``children`` list produced by ``to_state``.

          Args:
              ctx: Connector context passed to ``get_connector`` for every child
                  and to the new pipeline.
              params: List of ``(name, subparams)`` pairs, one per connector.

          Returns:
              A new ``ActionConnectorPipeline`` holding the restored connectors.

          Raises:
              AssertionError: If ``params`` is not a list.
              Exception: Any error raised while unpacking or restoring a child
                  state is logged and re-raised unchanged.
          """
          assert isinstance(
              params, list
          ), "ActionConnectorPipeline takes a list of connector params."
          connectors = []
          for state in params:
              try:
                  name, subparams = state
                  connectors.append(get_connector(name, ctx, subparams))
              except Exception:
                  logger.error(f"Failed to de-serialize connector state: {state}")
                  # Bare raise keeps the original exception and traceback.
                  raise
          return ActionConnectorPipeline(ctx, connectors)
  51. register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)  # Register the pipeline class under its own class name. NOTE(review): presumably this is what lets get_connector() resolve that name during from_state - confirm against the registry.